Compare commits

..

1 Commits

Author SHA1 Message Date
Viktor Liu
3251bc79fa Preresolve domains 2025-06-13 14:53:27 +02:00
178 changed files with 4212 additions and 8277 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.20"
SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -134,7 +134,6 @@ jobs:
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
run: |
set -x
@@ -181,7 +180,6 @@ jobs:
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
- name: Install modules
run: go mod tidy

View File

@@ -149,7 +149,6 @@ nfpms:
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids:
- netbird
goarch: amd64
@@ -165,7 +164,6 @@ dockers:
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids:
- netbird
goarch: arm64

View File

@@ -14,9 +14,6 @@
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
@@ -32,13 +29,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
<br/>
</strong>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
</a>
</p>

View File

@@ -1,9 +1,6 @@
FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,7 +1,6 @@
FROM alpine:3.21.0
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
COPY netbird /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird

View File

@@ -59,8 +59,6 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
@@ -108,8 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -134,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// Stop the internal client and free the resources
@@ -176,55 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos}
}
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()

View File

@@ -1,27 +0,0 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,23 +7,30 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum
}
// PeerInfoArray is a wrapper of []PeerInfo
// PeerInfoCollection made for Java layer to get non default types as collection
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct {
items []PeerInfo
}
// Add new PeerInfo to the collection
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
array.items = append(array.items, s)
return array
}
// Get return an element of the collection
func (array *PeerInfoArray) Get(i int) *PeerInfo {
func (array PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i]
}
// Size return with the size of the collection
func (array *PeerInfoArray) Size() int {
func (array PeerInfoArray) Size() int {
return len(array.items)
}

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal"
)
// Preferences exports a subset of the internal config for gomobile
// Preferences export a subset of the internal config for gomobile
type Preferences struct {
configInput internal.ConfigInput
}
// NewPreferences creates a new Preferences instance
// NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{
ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci}
}
// GetManagementURL reads URL from config file
// GetManagementURL read url from config file
func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err
}
// SetManagementURL stores the given URL and waits for commit
// SetManagementURL store the given url and wait for commit
func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url
}
// GetAdminURL reads URL from config file
// GetAdminURL read url from config file
func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err
}
// SetAdminURL stores the given URL and waits for commit
// SetAdminURL store the given url and wait for commit
func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url
}
// GetPreSharedKey reads pre-shared key from config file
// GetPreSharedKey read preshared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err
}
// SetPreSharedKey stores the given key and waits for commit
// SetPreSharedKey store the given key and wait for commit
func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled stores whether Rosenpass is enabled
// SetRosenpassEnabled store if rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled reads Rosenpass enabled status from config file
// GetRosenpassEnabled read rosenpass enabled from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive stores the given permissive setting and waits for commit
// SetRosenpassPermissive store the given permissive and wait for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
// GetRosenpassPermissive read rosenpass permissive from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
@@ -107,119 +107,7 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err
}
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)
return err

View File

@@ -69,10 +69,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err
}
status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -120,7 +117,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}

View File

@@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
"This overrides any policies received from the management service.")
}

View File

@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
t.Fatal(err)
}

View File

@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}

View File

@@ -3,7 +3,6 @@ package conntrack
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
@@ -20,10 +19,6 @@ const (
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
)
// ICMPConnKey uniquely identifies an ICMP connection
@@ -34,7 +29,7 @@ type ICMPConnKey struct {
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID)
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
}
// ICMPConnTrack represents an ICMP connection state
@@ -55,72 +50,6 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger
}
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 {
@@ -164,64 +93,30 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -243,7 +138,7 @@ func (t *ICMPTracker) track(
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
}
})
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
}
b.ResetTimer()

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut)
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn)
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
}
}

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr)
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr)
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
}
var rxPackets, txPackets uint64

View File

@@ -1,408 +0,0 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@@ -1,416 +0,0 @@
package uspfilter
import (
"fmt"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
)
// BenchmarkDNATTranslation measures the performance of DNAT operations
func BenchmarkDNATTranslation(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
description string
}{
{
name: "tcp_with_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: true,
description: "TCP packet with DNAT translation enabled",
},
{
name: "tcp_without_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
description: "TCP packet without DNAT (baseline)",
},
{
name: "udp_with_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: true,
description: "UDP packet with DNAT translation enabled",
},
{
name: "udp_without_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
description: "UDP packet without DNAT (baseline)",
},
{
name: "icmp_with_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: true,
description: "ICMP packet with DNAT translation enabled",
},
{
name: "icmp_without_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: false,
description: "ICMP packet without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mapping if needed
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
if sc.setupDNAT {
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Create test packets
srcIP := netip.MustParseAddr("172.16.0.1")
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
// Pre-establish connection for reverse DNAT test
if sc.setupDNAT {
manager.filterOutbound(outboundPacket, 0)
}
b.ResetTimer()
// Benchmark outbound DNAT translation
b.Run("outbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
// Benchmark inbound reverse DNAT translation
if sc.setupDNAT {
b.Run("inbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
manager.filterInbound(packet, 0)
}
})
}
})
}
}
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup multiple DNAT mappings
numMappings := 100
originalIPs := make([]netip.Addr, numMappings)
translatedIPs := make([]netip.Addr, numMappings)
for i := 0; i < numMappings; i++ {
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
require.NoError(b, err)
}
srcIP := netip.MustParseAddr("172.16.0.1")
// Pre-generate packets
outboundPackets := make([][]byte, numMappings)
inboundPackets := make([][]byte, numMappings)
for i := 0; i < numMappings; i++ {
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
// Establish connections
manager.filterOutbound(outboundPackets[i], 0)
}
b.ResetTimer()
b.Run("concurrent_outbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
i++
}
})
})
b.Run("concurrent_inbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
manager.filterInbound(packet, 0)
i++
}
})
})
}
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
func BenchmarkDNATScaling(b *testing.B) {
mappingCounts := []int{1, 10, 100, 1000}
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mappings
for i := 0; i < count; i++ {
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Test with the last mapping added (worst case for lookup)
srcIP := netip.MustParseAddr("172.16.0.1")
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
}
}
// generateDNATTestPacket creates a test packet for DNAT benchmarking
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
tb.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: proto,
}
var transportLayer gopacket.SerializableLayer
switch proto {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
case layers.IPProtocolICMPv4:
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
transportLayer = icmp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(tb, err)
return buf.Bytes()
}
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
func BenchmarkChecksumUpdate(b *testing.B) {
// Create test data for checksum calculations
testData := make([]byte, 64) // Typical packet size for checksum testing
for i := range testData {
testData[i] = byte(i)
}
b.Run("ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
}
})
b.Run("icmp_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = icmpChecksum(testData)
}
})
b.Run("incremental_update", func(b *testing.B) {
oldBytes := []byte{192, 168, 1, 100}
newBytes := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
}
})
}
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Create fresh packet each time to isolate allocation testing
testPacket := make([]byte, len(packet))
copy(testPacket, packet)
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
}
}
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
func BenchmarkDirectIPExtraction(b *testing.B) {
// Create a test packet
srcIP := netip.MustParseAddr("172.16.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
b.Run("direct_byte_access", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Direct extraction from packet bytes
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
}
})
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
// Extract using decoder (traditional method)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
_ = dst
}
})
}
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
func BenchmarkChecksumOptimizations(b *testing.B) {
// Create test IPv4 header (20 bytes)
header := make([]byte, 20)
for i := range header {
header[i] = byte(i)
}
// Clear checksum field
header[10] = 0
header[11] = 0
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(header)
}
})
// Test incremental checksum updates
oldIP := []byte{192, 168, 1, 100}
newIP := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
b.Run("optimized_incremental_update", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
}
})
}

View File

@@ -1,145 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
// Add DNAT mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcPort uint16
dstPort uint16
}{
{"TCP", layers.IPProtocolTCP, 12345, 80},
{"UDP", layers.IPProtocolUDP, 12345, 53},
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test outbound DNAT translation
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound packet (should translate destination)
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, translated, "Outbound packet should be translated")
// Verify destination IP was changed
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
// Test inbound reverse DNAT translation
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet (should reverse translate source)
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, reversed, "Inbound packet should be reverse translated")
// Verify source IP was changed back to original
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
// Test that checksums are recalculated correctly
if tc.protocol != layers.IPProtocolICMPv4 {
// For TCP/UDP, verify the transport checksum was updated
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
}
})
}
}
// parsePacket helper to create a decoder for testing
func parsePacket(t testing.TB, packetData []byte) *decoder {
t.Helper()
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
// Test adding mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
// Verify mapping exists
result, exists := manager.getDNATTranslation(originalIP)
require.True(t, exists)
require.Equal(t, translatedIP, result)
// Test reverse lookup
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
require.True(t, exists)
require.Equal(t, originalIP, reverseResult)
// Test removing mapping
err = manager.RemoveInternalDNATMapping(originalIP)
require.NoError(t, err)
// Verify mapping no longer exists
_, exists = manager.getDNATTranslation(originalIP)
require.False(t, exists)
_, exists = manager.findReverseDNATMapping(translatedIP)
require.False(t, exists)
// Test error cases
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
require.Error(t, err, "Should reject invalid original IP")
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
require.Error(t, err, "Should reject invalid translated IP")
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}

View File

@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.filterOutbound(packetData, 0)
dropped := m.processOutgoingHooks(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -104,12 +104,6 @@ type Manager struct {
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
}
// decoder for packages
@@ -195,7 +189,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
}
m.routingEnabled.Store(false)
@@ -526,6 +519,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -572,14 +581,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// FilterOutBound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size)
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
}
// FilterInbound filters incoming packets
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
return m.filterInbound(packetData, size)
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
}
// UpdateLocalIPs updates the list of local IPs
@@ -587,7 +596,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -609,8 +618,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
return false
}
@@ -662,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
}
}
@@ -675,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
}
}
@@ -714,9 +723,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false
}
// filterInbound implements filtering logic for incoming packets.
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -738,15 +747,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false
}

View File

@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
// Test packet
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.filterOutbound(testOut, 0)
manager.processOutgoingHooks(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(testIn, 0)
manager.dropFilter(testIn, 0)
}
})
}
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.filterOutbound(outPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.filterInbound(inPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
}
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
// Data transfer
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
// Connection teardown
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
}
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
// Pre-generate test packets
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.filterOutbound(outPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
})
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
})

View File

@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)

View File

@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.filterInbound(buf.Bytes(), 0) {
if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.filterOutbound(buf.Bytes(), 0)
result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.filterOutbound(buf.Bytes(), 0)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result)
}
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.filterInbound(testBuf.Bytes(), 0)
drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@@ -1,94 +0,0 @@
package bind
import (
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/monotime"
)
const (
saveFrequency = int64(5 * time.Second)
)
type PeerRecord struct {
Address netip.AddrPort
LastActivity atomic.Int64 // UnixNano timestamp
}
type ActivityRecorder struct {
mu sync.RWMutex
peers map[string]*PeerRecord // publicKey to PeerRecord map
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
}
func NewActivityRecorder() *ActivityRecorder {
return &ActivityRecorder{
peers: make(map[string]*PeerRecord),
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
}
}
// GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]time.Time {
r.mu.RLock()
defer r.mu.RUnlock()
activities := make(map[string]time.Time, len(r.peers))
for key, record := range r.peers {
unixNano := record.LastActivity.Load()
activities[key] = time.Unix(0, unixNano)
}
return activities
}
// UpsertAddress adds or updates the address for a publicKey
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
r.mu.Lock()
defer r.mu.Unlock()
if pr, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, pr.Address)
pr.Address = address
} else {
record := &PeerRecord{
Address: address,
}
record.LastActivity.Store(monotime.Now())
r.peers[publicKey] = record
}
r.addrToPeer[address] = r.peers[publicKey]
}
func (r *ActivityRecorder) Remove(publicKey string) {
r.mu.Lock()
defer r.mu.Unlock()
if record, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, record.Address)
delete(r.peers, publicKey)
}
}
// record updates LastActivity for the given address using atomic store
func (r *ActivityRecorder) record(address netip.AddrPort) {
r.mu.RLock()
record, ok := r.addrToPeer[address]
r.mu.RUnlock()
if !ok {
log.Warnf("could not find record for address %s", address)
return
}
now := monotime.Now()
last := record.LastActivity.Load()
if now-last < saveFrequency {
return
}
_ = record.LastActivity.CompareAndSwap(last, now)
}

View File

@@ -1,27 +0,0 @@
package bind
import (
"net/netip"
"testing"
"time"
)
func TestActivityRecorder_GetLastActivities(t *testing.T) {
peer := "peer1"
ar := NewActivityRecorder()
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
activities := ar.GetLastActivities()
p, ok := activities[peer]
if !ok {
t.Fatalf("Expected activity for peer %s, but got none", peer)
}
if p.IsZero() {
t.Fatalf("Expected activity for peer %s, but got zero", peer)
}
if p.Before(time.Now().Add(-2 * time.Minute)) {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
}
}

View File

@@ -1,153 +0,0 @@
package bind
import (
"errors"
"net"
"time"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn is a composite PacketConn that can handle both IPv4 and IPv6
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
}
// NewDualStackPacketConn creates a new dual-stack packet connection
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from both IPv4 and IPv6 connections
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
// Prefer IPv4 if available
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection (for compatibility)
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -1,8 +1,7 @@
package bind
import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
@@ -11,6 +10,7 @@ import (
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -26,8 +26,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -51,26 +51,22 @@ type ICEBind struct {
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
address wgaddr.Address
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
activityRecorder: NewActivityRecorder(),
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -104,19 +100,15 @@ func (s *ICEBind) Close() error {
return s.StdNetBind.Close()
}
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux != nil {
return s.udpMux, nil
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return nil, errors.New("ICEBind has not been initialized yet")
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
@@ -148,46 +140,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
log.Errorf("ICEBind: unexpected address type: %T", conn.LocalAddr())
return nil
}
isIPv6 := localAddr.IP.To4() == nil
if isIPv6 {
s.ipv6Conn = conn
} else {
s.ipv4Conn = conn
}
needsNewMux := s.udpMux == nil && (s.ipv4Conn != nil || s.ipv6Conn != nil)
needsUpgrade := s.udpMux != nil && s.ipv4Conn != nil && s.ipv6Conn != nil
if needsNewMux || needsUpgrade {
var iceMuxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
iceMuxConn = NewDualStackPacketConn(s.ipv4Conn, s.ipv6Conn)
case s.ipv4Conn != nil:
iceMuxConn = s.ipv4Conn
default:
iceMuxConn = s.ipv6Conn
}
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: iceMuxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
},
)
}
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
@@ -234,17 +198,7 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
if sizes[i] == 0 {
continue
}
udpAddr, ok := msg.Addr.(*net.UDPAddr)
if !ok {
log.Errorf("ICEBind: unexpected address type: %T", msg.Addr)
continue
}
addrPort := udpAddr.AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
@@ -265,10 +219,9 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
if s.udpMux != nil {
if err := s.udpMux.HandleSTUNMessage(msg, addr); err != nil {
log.Warnf("failed to handle STUN packet: %v", err)
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
@@ -304,13 +257,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
if isTransportPkg(buffs, sizes[0]) {
if ep, ok := eps[0].(*Endpoint); ok {
c.activityRecorder.record(ep.AddrPort)
}
}
return 1, nil
}
}
@@ -326,19 +272,3 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
}
msgsPool.Put(msgs)
}
func isTransportPkg(buffers [][]byte, n int) bool {
// The first buffer should contain at least 4 bytes for type
if len(buffers[0]) < 4 {
return true
}
// WireGuard packet type is a little-endian uint32 at start
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
// Check if packetType matches known WireGuard message types
if packetType == 4 && n > 32 {
return true
}
return false
}

View File

@@ -342,9 +342,6 @@ func (m *UDPMuxDefault) Close() error {
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
if dualStackConn, ok := m.params.UDPConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(buf, rAddr)
}
return m.params.UDPConn.WriteTo(buf, rAddr)
}

View File

@@ -126,11 +126,6 @@ type udpConn struct {
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
// Check if this is a dual-stack connection and handle IPv6 addresses properly
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(b, addr)
}
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
@@ -146,11 +141,6 @@ func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (i
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(b, addr)
}
return u.PacketConn.WriteTo(b, addr)
}
@@ -158,11 +148,6 @@ func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(b, addr)
}
return u.PacketConn.WriteTo(b, addr)
}

View File

@@ -276,7 +276,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]time.Time {
return nil
}

View File

@@ -16,7 +16,6 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -37,18 +36,16 @@ const (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
device *device.Device
deviceName string
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
device: device,
deviceName: deviceName,
}
wgCfg.startUAPI()
return wgCfg
@@ -90,19 +87,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return ipcErr
}
if endpoint != nil {
addr, err := netip.ParseAddr(endpoint.IP.String())
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
@@ -119,10 +104,7 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
c.activityRecorder.Remove(peerKey)
return ipcErr
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
@@ -223,10 +205,6 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
return parseStatus(c.deviceName, ipcStr)
}
func (c *WGUSPConfigurer) LastActivities() map[string]time.Time {
return c.activityRecorder.GetLastActivities()
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error

View File

@@ -24,7 +24,6 @@ type WGTunDevice struct {
mtu int
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
name string
device *device.Device
@@ -33,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -41,7 +40,6 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
}
}
@@ -51,13 +49,6 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
@@ -79,7 +70,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -9,11 +9,11 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// FilterOutbound filter outgoing packets from host to external destinations
FilterOutbound(packetData []byte, size int) bool
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.FilterInbound(buf[offset:], len(buf)) {
if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()

View File

@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -19,5 +19,4 @@ type WGConfigurer interface {
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]time.Time
}

View File

@@ -43,7 +43,6 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
DisableDNS bool
}
// WGIface represents an interface instance
@@ -217,14 +216,6 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats()
}
func (w *WGIface) LastActivities() map[string]time.Time {
w.mu.Lock()
defer w.mu.Unlock()
return w.configurer.LastActivities()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil

View File

@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
}
// RemovePacketHook mocks base method.

View File

@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// SetNetwork mocks base method.

View File

@@ -398,15 +398,11 @@ func (d *DefaultManager) squashAcceptRules(
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},

View File

@@ -330,434 +330,6 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{

View File

@@ -223,8 +223,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
}
if _, err := config.apply(input); err != nil {
@@ -319,6 +317,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
@@ -414,15 +416,9 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
if runtime.GOOS == "android" {
// default to disabled SSH on Android for security
log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
updated = true
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
)
@@ -25,11 +26,11 @@ import (
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
enabledLocally bool
rosenpassEnabled bool
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
@@ -38,12 +39,12 @@ type ConnMgr struct {
lazyCtxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
rosenpassEnabled: engineConfig.RosenpassEnabled,
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
@@ -63,11 +64,6 @@ func (e *ConnMgr) Start(ctx context.Context) {
return
}
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
@@ -87,12 +83,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return nil
}
if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
return nil
}
log.Warnf("lazy connection manager is enabled by management feature flag")
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager()
@@ -142,7 +133,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
@@ -210,7 +201,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
if !ok {
return
}
defer conn.Close(false)
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
@@ -220,28 +211,23 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
}
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
// If locally the lazy connection is disabled, we force the peer connection open.
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
}
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
return conn, true
}
func (e *ConnMgr) Close() {
@@ -258,7 +244,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
@@ -289,7 +275,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {

View File

@@ -17,6 +17,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -525,13 +526,17 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) {
addr := net.UDPAddr{Port: initPort}
addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
conn, err := net.ListenUDP("udp", &addr)
if err == nil {
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
closeConnWithLog(conn)
return returnPort, nil
return initPort, nil
}
// if the port is already in use, ask the system for a free port

View File

@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
shouldMatch bool
}{
{
name: "when port is 0 use random port",
name: "not provided, fallback to default",
port: 0,
want: 0,
shouldMatch: false,
want: 51820,
shouldMatch: true,
},
{
name: "provided and available",
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
shouldMatch: false,
},
}
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil {
t.Errorf("freePort error = %v", err)
}
@@ -39,14 +39,6 @@ func Test_freePort(t *testing.T) {
_ = c1.Close()
}(c1)
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
tests[1].port++
tests[1].want++
}
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -11,10 +11,9 @@ import (
)
const (
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 1
)
type SubdomainMatcher interface {

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: true,
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true,
},
},
{
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true,
nbdns.PriorityUpstream: false,
nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false,
},
},
{
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true,
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true,
},
},
}
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityUpstream, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "test.sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
},
query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",

View File

@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain,
handler: s.localResolver,
priority: PriorityLocal,
priority: PriorityMatchDomain,
})
for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS {
basePriority := PriorityUpstream
basePriority := PriorityMatchDomain
if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault
}
@@ -588,14 +588,10 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i
// Check if we're about to overlap with the next priority tier.
// This boundary check ensures that the priority of upstream handlers does not conflict
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
// Check if we're about to overlap with the next priority tier
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityUpstream-PriorityDefault)
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
break
}

View File

@@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
priority: PriorityMatchDomain,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone,
@@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
},
initSerial: 0,
@@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal,
priority: PriorityMatchDomain,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
},
initSerial: 0,
@@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
},
initSerial: 0,
@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
@@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
}
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
testCases := []struct {
name string
@@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
priority: PriorityMatchDomain - 1,
},
}
@@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
priority: PriorityMatchDomain - 1,
},
"upstream-other": {
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
}
@@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
@@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
@@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityUpstream + 1,
priority: PriorityMatchDomain + 1,
},
// Keep existing groups with their original priorities
{
@@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
@@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
priority: PriorityMatchDomain - 1,
},
// Add group3 with lowest priority
{
@@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityUpstream - 2,
priority: PriorityMatchDomain - 2,
},
},
expectedHandlers: map[string]string{
@@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
@@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityUpstream - 1,
priority: PriorityMatchDomain - 1,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
{
domain: "new.com",
handler: &mockHandler{
Id: "upstream-new",
},
priority: PriorityUpstream,
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
@@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
// Verify refcount is 2
zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
// Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
}
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1945,111 +1945,3 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
}
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -104,21 +103,19 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():
logger.Tracef("%s has been stopped", u)
log.Tracef("%s has been stopped", u)
return
default:
}
@@ -135,35 +132,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
}
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
}
// count the fails only if they happen sequentially
u.failsCount.Store(0)
return
}
u.failsCount.Add(1)
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
}
@@ -388,13 +385,3 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -84,10 +84,3 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
}
return false
}
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -36,10 +36,3 @@ func newUpstreamResolver(
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
}
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -18,20 +18,14 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct {
listenAddress string
ttl uint32
@@ -44,18 +38,16 @@ type DNSForwarder struct {
mutex sync.RWMutex
fwdEntries []*ForwarderEntry
firewall firewaller
resolver resolver
firewall firewall.Manager
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
}
}
@@ -65,17 +57,14 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server
mux := dns.NewServeMux()
f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
// TCP server
tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{
Addr: f.listenAddress,
Net: "tcp",
@@ -98,13 +87,30 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown)
return <-errCh
}
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
}
func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -151,31 +157,22 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(w, query, resp, domain, err)
return nil
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.updateInternalState(domain, ips)
f.addIPsToResponse(resp, domain, ips)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
@@ -209,8 +206,9 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
}
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" {
for _, ip := range ips {
var prefix netip.Prefix
@@ -341,3 +339,16 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches
}
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,21 +1,11 @@
package dnsfwd
import (
"context"
"fmt"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -23,7 +13,7 @@ import (
func Test_getMatchingEntries(t *testing.T) {
testCases := []struct {
name string
storedMappings map[string]route.ResID
storedMappings map[string]route.ResID // key: domain pattern, value: resId
queryDomain string
expectedResId route.ResID
}{
@@ -54,7 +44,7 @@ func Test_getMatchingEntries(t *testing.T) {
{
name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.example.org",
queryDomain: "foo.notexample.com",
expectedResId: "",
},
{
@@ -111,619 +101,3 @@ func Test_getMatchingEntries(t *testing.T) {
})
}
}
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -38,6 +38,7 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
@@ -174,7 +175,8 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
statusRecorder *peer.Status
statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager
routeManager routemanager.Manager
@@ -381,13 +383,7 @@ func (e *Engine) Start() error {
}
e.stateManager.Start()
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
@@ -404,7 +400,6 @@ func (e *Engine) Start() error {
InitialRoutes: initialRoutes,
StateManager: e.stateManager,
DNSServer: dnsServer,
DNSFeatureFlag: dnsFeatureFlag,
PeerStore: e.peerStore,
DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes,
@@ -456,7 +451,9 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
@@ -491,9 +488,9 @@ func (e *Engine) createFirewall() error {
}
func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err)
return fmt.Errorf("enable server router: %w", err)
}
if e.config.BlockLANAccess {
@@ -1012,6 +1009,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1022,7 +1021,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
@@ -1257,7 +1255,7 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close(false)
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
@@ -1304,12 +1302,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1332,16 +1331,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn, ok := e.peerStore.PeerConn(msg.Key)
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
msgType := msg.GetBody().GetType()
if msgType != sProto.Body_GO_IDLE {
e.connMgr.ActivatePeer(e.ctx, conn)
}
switch msg.GetBody().Type {
case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1398,8 +1392,6 @@ func (e *Engine) receiveSignalEvents() {
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE:
case sProto.Body_GO_IDLE:
e.connMgr.DeactivatePeer(conn)
}
return nil
@@ -1497,12 +1489,7 @@ func (e *Engine) close() {
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
if runtime.GOOS != "android" {
// nolint:nilnil
return nil, nil, false, nil
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx)
info.SetFlags(
e.config.RosenpassEnabled,
@@ -1519,12 +1506,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, false, err
return nil, nil, err
}
routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
dnsFeatureFlag := toDNSFeatureFlag(netMap)
return routes, &dnsCfg, dnsFeatureFlag, nil
return routes, &dnsCfg, nil
}
func (e *Engine) newWgIface() (*iface.WGIface, error) {
@@ -1541,7 +1527,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
}
switch runtime.GOOS {
@@ -1572,14 +1557,18 @@ func (e *Engine) wgInterfaceCreate() (err error) {
return err
}
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
// due to tests where we are using a mocked version of the DNS server
if e.dnsServer != nil {
return e.dnsServer, nil
return nil, e.dnsServer, nil
}
switch runtime.GOOS {
case "android":
routes, dnsConfig, err := e.readInitialSettings()
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx,
e.wgInterface,
@@ -1590,19 +1579,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
e.config.DisableDNS,
)
go e.mobileDep.DnsReadyListener.OnReady()
return dnsServer, nil
return routes, dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
if err != nil {
return nil, err
return nil, nil, err
}
return dnsServer, nil
return nil, dnsServer, nil
}
}

View File

@@ -36,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -96,7 +97,6 @@ type MockWGIface struct {
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]time.Time
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
@@ -187,13 +187,6 @@ func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}
func (m *MockWGIface) LastActivities() map[string]time.Time {
if m.LastActivitiesFunc != nil {
return m.LastActivitiesFunc()
}
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
@@ -411,7 +404,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
type testCase struct {
@@ -800,7 +793,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
@@ -998,7 +991,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
@@ -1483,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
if err != nil {
return nil, "", err
}

View File

@@ -38,5 +38,4 @@ type wgIfaceBase interface {
GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]time.Time
}

View File

@@ -13,7 +13,7 @@ import (
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
wgIface WgInterface
wgIface lazyconn.WGIface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
@@ -22,7 +22,7 @@ type Listener struct {
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,

View File

@@ -1,27 +1,18 @@
package activity
import (
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface WgInterface
wgIface lazyconn.WGIface
peers map[peerid.ConnID]*Listener
done chan struct{}
@@ -29,7 +20,7 @@ type Manager struct {
mu sync.Mutex
}
func NewManager(wgIface WgInterface) *Manager {
func NewManager(wgIface lazyconn.WGIface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,

View File

@@ -0,0 +1,70 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}

View File

@@ -0,0 +1,156 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -1,152 +0,0 @@
package inactivity
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
const (
checkInterval = 1 * time.Minute
DefaultInactivityThreshold = 15 * time.Minute
MinimumInactivityThreshold = 1 * time.Minute
)
type WgInterface interface {
LastActivities() map[string]time.Time
}
type Manager struct {
inactivePeersChan chan map[string]struct{}
iface WgInterface
interestedPeers map[string]*lazyconn.PeerConfig
inactivityThreshold time.Duration
}
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
if err != nil {
inactivityThreshold = DefaultInactivityThreshold
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
}
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
return &Manager{
inactivePeersChan: make(chan map[string]struct{}, 1),
iface: iface,
interestedPeers: make(map[string]*lazyconn.PeerConfig),
inactivityThreshold: inactivityThreshold,
}
}
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
if m == nil {
// return a nil channel that blocks forever
return nil
}
return m.inactivePeersChan
}
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
if m == nil {
return
}
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
return
}
peerCfg.Log.Infof("adding peer to inactivity manager")
m.interestedPeers[peerCfg.PublicKey] = peerCfg
}
func (m *Manager) RemovePeer(peer string) {
if m == nil {
return
}
pi, ok := m.interestedPeers[peer]
if !ok {
return
}
pi.Log.Debugf("remove peer from inactivity manager")
delete(m.interestedPeers, peer)
}
func (m *Manager) Start(ctx context.Context) {
if m == nil {
return
}
ticker := newTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C():
idlePeers, err := m.checkStats()
if err != nil {
log.Errorf("error checking stats: %v", err)
return
}
if len(idlePeers) == 0 {
continue
}
m.notifyInactivePeers(ctx, idlePeers)
}
}
}
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
select {
case m.inactivePeersChan <- inactivePeers:
case <-ctx.Done():
return
default:
return
}
}
func (m *Manager) checkStats() (map[string]struct{}, error) {
lastActivities := m.iface.LastActivities()
idlePeers := make(map[string]struct{})
for peerID, peerCfg := range m.interestedPeers {
lastActive, ok := lastActivities[peerID]
if !ok {
// when peer is in connecting state
peerCfg.Log.Warnf("peer not found in wg stats")
continue
}
if time.Since(lastActive) > m.inactivityThreshold {
peerCfg.Log.Infof("peer is inactive since: %v", lastActive)
idlePeers[peerID] = struct{}{}
}
}
return idlePeers, nil
}
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
if configuredThreshold == nil {
return DefaultInactivityThreshold, nil
}
if *configuredThreshold < MinimumInactivityThreshold {
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
}
return *configuredThreshold, nil
}

View File

@@ -1,113 +0,0 @@
package inactivity
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
type mockWgInterface struct {
lastActivities map[string]time.Time
}
func (m *mockWgInterface) LastActivities() map[string]time.Time {
return m.lastActivities
}
func TestPeerTriggersInactivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]time.Time{
peerID: time.Now().Add(-20 * time.Minute),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case inactivePeers := <-manager.inactivePeersChan:
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
case <-time.After(1 * time.Second):
t.Fatal("expected inactivity event, but none received")
}
}
func TestPeerTriggersActivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]time.Time{
peerID: time.Now().Add(-5 * time.Minute),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case <-manager.inactivePeersChan:
t.Fatal("expected inactive peer to be marked inactive")
case <-time.After(1 * time.Second):
// No inactivity event should be received
}
}
// fakeTickerMock implements Ticker interface for testing
type fakeTickerMock struct {
CChan chan time.Time
}
func (f *fakeTickerMock) C() <-chan time.Time {
return f.CChan
}
func (f *fakeTickerMock) Stop() {}

View File

@@ -1,24 +0,0 @@
package inactivity
import "time"
var newTicker = func(d time.Duration) Ticker {
return &realTicker{t: time.NewTicker(d)}
}
type Ticker interface {
C() <-chan time.Time
Stop()
}
type realTicker struct {
t *time.Ticker
}
func (r *realTicker) C() <-chan time.Time {
return r.t.C
}
func (r *realTicker) Stop() {
r.t.Stop()
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
@@ -42,46 +43,60 @@ type Config struct {
type Manager struct {
engineCtx context.Context
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityManager *inactivity.Manager
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
// Route HA group management
// If any peer in the same HA group is active, all peers in that group should prevent going idle
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex
routesMu sync.RWMutex // protects route mappings
onInactive chan peerid.ConnID
}
// NewManager creates a new lazy connection manager
// engineCtx is the context for creating peer Connection
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
engineCtx: engineCtx,
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
peerToHAGroups: make(map[string][]route.HAUniqueID),
haGroupToPeers: make(map[route.HAUniqueID][]string),
onInactive: make(chan peerid.ConnID),
}
if wgIface.IsUserspaceBind() {
m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
} else {
log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
@@ -116,28 +131,24 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
}
}
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
len(m.haGroupToPeers), len(m.peerToHAGroups))
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
if m.inactivityManager != nil {
go m.inactivityManager.Start(ctx)
}
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(peerConnID)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
}
}
}
// ExcludePeer marks peers for a permanent connection
@@ -145,7 +156,7 @@ func (m *Manager) Start(ctx context.Context) {
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -176,7 +187,7 @@ func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(&peerCfg); err != nil {
if err := m.addActivePeer(ctx, peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
@@ -206,24 +217,20 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(peerCfg)
}
return false, nil
}
// AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -233,7 +240,7 @@ func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
continue
}
if err := m.addActivePeer(&cfg); err != nil {
if err := m.addActivePeer(ctx, cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
@@ -250,7 +257,7 @@ func (m *Manager) RemovePeer(peerID string) {
// ActivatePeer activates a peer connection when a signal message is received
// Also activates all peers in the same HA groups as this peer
func (m *Manager) ActivatePeer(peerID string) (found bool) {
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, mp := m.getPeerForActivation(peerID)
@@ -258,42 +265,15 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) {
return false
}
if !m.activateSinglePeer(cfg, mp) {
if !m.activateSinglePeer(ctx, cfg, mp) {
return false
}
m.activateHAGroupPeers(cfg)
m.activateHAGroupPeers(ctx, peerID)
return true
}
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
// Returns nil values if the peer should be skipped
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
@@ -315,120 +295,82 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
return cfg, mp
}
// activateSinglePeer activates a single peer
// return true if the peer was activated, false if it was already active
func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
if mp.expectedWatcher == watcherInactivity {
// activateSinglePeer activates a single peer (internal method)
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false
}
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
m.inactivityManager.AddPeer(cfg)
cfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
var peersToActivate []string
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
if len(haGroups) == 0 {
m.routesMu.RUnlock()
triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, peerID := range peers {
if peerID != triggeredPeerCfg.PublicKey {
peersToActivate = append(peersToActivate, peerID)
}
}
}
m.routesMu.RUnlock()
activatedCount := 0
for _, peerID := range peersToActivate {
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
if m.activateSinglePeer(cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
for _, peerID := range peers {
if peerID == triggerPeerID {
continue
}
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
}
}
if activatedCount > 0 {
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
activatedCount, triggeredPeerCfg.PublicKey, haGroups)
activatedCount, triggerPeerID, haGroups)
}
}
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(&peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
m.managedPeers[peerCfg.PublicKey] = peerCfg
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: peerCfg,
peerCfg: &peerCfg,
expectedWatcher: watcherInactivity,
}
m.inactivityManager.AddPeer(peerCfg)
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
return nil
}
@@ -440,7 +382,12 @@ func (m *Manager) removePeer(peerID string) {
cfg.Log.Infof("removing lazy peer")
m.inactivityManager.RemovePeer(cfg.PublicKey)
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
@@ -450,8 +397,12 @@ func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
@@ -464,56 +415,7 @@ func (m *Manager) close() {
log.Infof("lazy connection manager closed")
}
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
return true
}
}
return false
}
func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// If any peer in the group is active, do defer idle
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
return true
}
}
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -530,56 +432,89 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
mp.peerCfg.Log.Infof("detected peer activity")
if !m.activateSinglePeer(mp.peerCfg, mp) {
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
return
}
m.activateHAGroupPeers(mp.peerCfg)
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
for peerID := range peerIDs {
peerCfg, ok := m.managedPeers[peerID]
if !ok {
log.Errorf("peer not found by peerId: %v", peerID)
continue
}
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
continue
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
continue
}
mp.peerCfg.Log.Infof("connection timed out")
if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
continue
}
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("connection timed out")
mp.peerCfg.Log.Infof("start activity monitor")
// this is blocking operation, potentially can be optimized
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
mp.expectedWatcher = watcherActivity
mp.peerCfg.Log.Infof("start activity monitor")
// just in case free up
m.inactivityMonitors[peerConnID].PauseTimer()
mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
continue
}
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -11,6 +11,4 @@ import (
type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
LastActivities() map[string]time.Time
}

View File

@@ -148,7 +148,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
return nil, err
}

View File

@@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
eventStr = "Ended"
}
log.Tracef("%s %s %s connection: %s:%d %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
c.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,

View File

@@ -117,9 +117,10 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
wg sync.WaitGroup
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher
wg sync.WaitGroup
// debug purpose
dumpState *stateDump
@@ -135,17 +136,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
peerConnDispatcher: services.PeerConnDispatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
}
return conn, nil
@@ -224,7 +226,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
}
// Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close(signalToRemote bool) {
func (conn *Conn) Close() {
conn.mu.Lock()
defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock()
@@ -234,12 +236,6 @@ func (conn *Conn) Close(signalToRemote bool) {
return
}
if signalToRemote {
if err := conn.signaler.SignalIdle(conn.config.Key); err != nil {
conn.Log.Errorf("failed to signal idle state to peer: %v", err)
}
}
conn.Log.Infof("close peer connection")
conn.ctxCancel()
@@ -321,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig
}
// IsConnected returns true if the peer is connected
// IsConnected unit tests only
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
return conn.evalStatus() == StatusConnected
return conn.currentConnPriority != conntype.None
}
func (conn *Conn) GetKey() string {
@@ -408,10 +404,15 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
}
func (conn *Conn) onICEStateDisconnected() {
@@ -449,6 +450,7 @@ func (conn *Conn) onICEStateDisconnected() {
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
}
changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -528,6 +530,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
func (conn *Conn) onRelayDisconnected() {
@@ -542,7 +545,11 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
}
if conn.wgProxyRelay != nil {

View File

@@ -68,13 +68,3 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
return nil
}
func (s *Signaler) SignalIdle(remoteKey string) error {
return s.signal.Send(&sProto.Message{
Key: s.wgPrivateKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &sProto.Body{
Type: sProto.Body_GO_IDLE,
},
})
}

View File

@@ -575,12 +575,13 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification {
d.mux.Unlock()
return
}
d.peerListChangedForNotification = false
d.mux.Unlock()
d.notifyPeerListChanged()

View File

@@ -146,8 +146,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: formatEndpoint(pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: formatEndpoint(pair.Remote.Address(), pair.Remote.Port()),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
@@ -405,12 +405,3 @@ func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
return conntype.ICEP2P
}
}
// formatEndpoint formats an IP address and port for display, adding brackets around IPv6 addresses
func formatEndpoint(addr string, port int) string {
parsed, err := netip.ParseAddr(addr)
if err == nil && parsed.Is6() {
return fmt.Sprintf("[%s]:%d", addr, port)
}
return fmt.Sprintf("%s:%d", addr, port)
}

View File

@@ -95,17 +95,6 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
}
func (s *Store) PeerConnIdle(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
p.Close(true)
}
func (s *Store) PeerConnClose(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
@@ -114,7 +103,7 @@ func (s *Store) PeerConnClose(pubKey string) {
if !ok {
return
}
p.Close(false)
p.Close()
}
func (s *Store) PeersPubKey() []string {

View File

@@ -4,16 +4,18 @@ import (
"context"
"fmt"
"reflect"
"runtime"
"time"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
@@ -21,7 +23,7 @@ import (
const (
handlerTypeDynamic = iota
handlerTypeDnsInterceptor
handlerTypeDomain
handlerTypeStatic
)
@@ -552,16 +554,40 @@ func (w *Watcher) Stop() {
w.currentChosenStatus = nil
}
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
switch handlerType(params.Route, params.UseNewDNSRoute) {
case handlerTypeDnsInterceptor:
return dnsinterceptor.New(params)
func HandlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDomain:
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerStore,
)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(params.WgInterface)
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
return dynamic.NewRoute(params, dnsAddr)
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
default:
return static.NewRoute(params)
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
}
@@ -570,8 +596,8 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
return handlerTypeStatic
}
if useNewDNSRoute {
return handlerTypeDnsInterceptor
if useNewDNSRoute && runtime.GOOS != "ios" {
return handlerTypeDomain
}
return handlerTypeDynamic
}

View File

@@ -7,12 +7,12 @@ import (
"time"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct {
name string
statuses map[route.ID]routerPeerStatus
@@ -811,12 +811,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
}
// create new clientNetwork
client := &Watcher{
handler: static.NewRoute(params),
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
}

View File

@@ -1,28 +0,0 @@
package common
import (
"time"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type HandlerParams struct {
Route *route.Route
RouteRefCounter *refcounter.RouteRefCounter
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
DnsRouterInterval time.Duration
StatusRecorder *peer.Status
WgInterface iface.WGIface
DnsServer dns.Server
PeerStore *peerstore.Store
UseNewDNSRoute bool
Firewall manager.Manager
FakeIPManager *fakeip.Manager
}

View File

@@ -4,23 +4,19 @@ import (
"context"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
@@ -28,16 +24,6 @@ import (
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
RemoveInternalDNATMapping(netip.Addr) error
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@@ -47,24 +33,25 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
wgInterface wgInterface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
}
func New(params common.HandlerParams) *DnsInterceptor {
func New(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
) *DnsInterceptor {
return &DnsInterceptor{
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
statusRecorder: params.StatusRecorder,
dnsServer: params.DnsServer,
wgInterface: params.WgInterface,
peerStore: params.PeerStore,
firewall: params.Firewall,
fakeIPManager: params.FakeIPManager,
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
interceptedDomains: make(domainMap),
peerStore: peerStore,
}
}
@@ -77,19 +64,127 @@ func (d *DnsInterceptor) AddRoute(context.Context) error {
return nil
}
// preResolveDomains performs background DNS resolution for non-wildcard domains
func (d *DnsInterceptor) preResolveDomains() {
for _, domain := range d.route.Domains {
domainStr := string(domain)
if strings.HasPrefix(domainStr, "*.") {
continue
}
domainStr = strings.TrimSuffix(domainStr, ".")
go func(domain string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := d.resolveAndUpdateDomain(ctx, domain); err != nil {
log.Debugf("pre-resolve failed for domain %s: %v", domain, err)
} else {
log.Tracef("pre-resolve completed for domain %s", domain)
}
}(domainStr)
}
}
// resolveAndUpdateDomain performs DNS resolution and updates domain prefixes
func (d *DnsInterceptor) resolveAndUpdateDomain(ctx context.Context, qDomain string) error {
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
if peerKey == "" {
return fmt.Errorf("no current peer key")
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
return fmt.Errorf("get upstream IP: %v", err)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(qDomain), dns.TypeA)
msg.Id = dns.Id()
msg.MsgHdr.AuthenticatedData = true
reply, err := d.exchangeWithUpstream(ctx, msg, upstreamIP)
if err != nil {
return fmt.Errorf("exchange with upstream: %v", err)
}
if reply == nil || len(reply.Answer) == 0 {
return nil
}
resolvedDomain := domain.Domain(dns.Fqdn(qDomain))
return d.processResolveResponse(reply, resolvedDomain, resolvedDomain)
}
// exchangeWithUpstream performs DNS exchange with the upstream server
func (d *DnsInterceptor) exchangeWithUpstream(ctx context.Context, msg *dns.Msg, upstreamIP netip.Addr) (*dns.Msg, error) {
client := &dns.Client{
Timeout: nbdns.UpstreamTimeout,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, msg, upstream)
return reply, err
}
// extractIPsFromDNSResponse extracts IP addresses from DNS answer records
func (d *DnsInterceptor) extractIPsFromDNSResponse(reply *dns.Msg, domainForLogging domain.Domain) []netip.Prefix {
if reply == nil || len(reply.Answer) == 0 {
return nil
}
var prefixes []netip.Prefix
for _, answer := range reply.Answer {
var ip netip.Addr
switch rr := answer.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", domainForLogging, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", domainForLogging, rr.AAAA)
continue
}
ip = addr
default:
continue
}
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
prefixes = append(prefixes, prefix)
}
return prefixes
}
// processResolveResponse extracts IPs from DNS response and updates domain prefixes
func (d *DnsInterceptor) processResolveResponse(reply *dns.Msg, resolvedDomain, originalDomain domain.Domain) error {
newPrefixes := d.extractIPsFromDNSResponse(reply, resolvedDomain)
if len(newPrefixes) > 0 {
return d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes)
}
return nil
}
func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock()
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// Routes should use fake IPs
routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
}
// AllowedIPs should use real IPs
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
@@ -97,10 +192,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
}
}
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
}
d.cleanupDNATMappings()
for _, domain := range d.route.Domains {
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
@@ -113,68 +206,6 @@ func (d *DnsInterceptor) RemoveRoute() error {
return nberrors.FormatErrorOrNil(merr)
}
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
return realPrefix
}
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
}
return realPrefix
}
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
// AllowedIPs always use real IPs
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
if err != nil {
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
}
if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
realPrefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
routePrefix := d.transformRealToFakePrefix(realPrefix)
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
}
// Add to AllowedIPs if we have a current peer (uses real IPs)
if d.currentPeerKey == "" {
return nil
}
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
}
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
if d.currentPeerKey == "" {
return nil
}
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
}
return nil
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -182,14 +213,20 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// AllowedIPs use real IPs
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
merr = multierror.Append(merr, err)
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
}
d.currentPeerKey = peerKey
go d.preResolveDomains()
return nberrors.FormatErrorOrNil(merr)
}
@@ -200,7 +237,6 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
var merr *multierror.Error
for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
@@ -213,18 +249,15 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 {
return
}
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
log.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
d.continueToNextHandler(w, r, "non A/AAAA query")
return
}
@@ -233,19 +266,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock()
if peerKey == "" {
d.writeDNSError(w, r, logger, "no current peer key")
d.writeDNSError(w, r, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
return
}
@@ -253,12 +280,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
reply, err := d.exchangeWithUpstream(context.TODO(), r, upstreamIP)
if err != nil {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
log.Errorf("failed writing DNS response: %v", err)
}
return
}
@@ -268,34 +294,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
log.Errorf("failed writing DNS response: %v", err)
}
}
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS error response: %v", err)
log.Errorf("failed to write DNS error response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
log.Errorf("failed writing DNS continue response: %v", err)
}
}
@@ -319,45 +345,13 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern)
if originalDomain == "" {
originalDomain = resolvedDomain
}
var newPrefixes []netip.Prefix
for _, answer := range r.Answer {
var ip netip.Addr
switch rr := answer.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
default:
continue
}
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
newPrefixes = append(newPrefixes, prefix)
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes)
if err := d.processResolveResponse(r, resolvedDomain, originalDomain); err != nil {
log.Errorf("failed to process DNS response: %v", err)
}
}
@@ -368,22 +362,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
return nil
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -392,163 +370,70 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error
var dnatMappings map[netip.Addr]netip.Addr
// Handle DNAT mappings for new prefixes
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
dnatMappings = make(map[netip.Addr]netip.Addr)
for _, prefix := range toAdd {
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
// Add new prefixes
for _, prefix := range toAdd {
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
merr = multierror.Append(merr, err)
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
}
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
resolvedDomain.SafeString(),
ref.Out,
)
}
}
d.addDNATMappings(dnatMappings)
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
// Routes use fake IPs
routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
}
// AllowedIPs use real IPs
if err := d.removeAllowedIP(prefix); err != nil {
merr = multierror.Append(merr, err)
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
d.removeDNATMappings(toRemove)
}
// Update domain prefixes using resolved domain as key - store real IPs
// Update domain prefixes using resolved domain as key
if len(toAdd) > 0 || len(toRemove) > 0 {
if d.route.KeepRoute {
// replace stored prefixes with old + added
// nolint:gocritic
newPrefixes = append(oldPrefixes, toAdd...)
}
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
if len(realPrefixes) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for _, prefix := range realPrefixes {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
}
}
}
}
// internalDnatFw checks if the firewall supports internal DNAT
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
if d.firewall == nil || runtime.GOOS != "android" {
return nil, false
}
fw, ok := d.firewall.(internalDNATer)
return fw, ok
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
if len(mappings) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
// cleanupDNATMappings removes all DNAT mappings for this interceptor
func (d *DnsInterceptor) cleanupDNATMappings() {
if _, ok := d.internalDnatFw(); !ok {
return
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes)
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
if _, ok := d.internalDnatFw(); !ok {
return
}
// Replace A and AAAA records with fake IPs
for _, answer := range reply.Answer {
switch rr := answer.(type) {
case *dns.A:
realIP, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
realIP, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {

View File

@@ -14,7 +14,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
@@ -53,16 +52,24 @@ type Route struct {
resolverAddr string
}
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
resolverAddr string,
) *Route {
return &Route{
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
interval: params.DnsRouterInterval,
statusRecorder: params.StatusRecorder,
wgInterface: params.WgInterface,
resolverAddr: resolverAddr,
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
}
}

View File

@@ -1,93 +0,0 @@
package fakeip
import (
"fmt"
"net/netip"
"sync"
)
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
type Manager struct {
mu sync.Mutex
nextIP netip.Addr // Next IP to allocate
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
baseIP netip.Addr // First usable IP: 240.0.0.1
maxIP netip.Addr // Last usable IP: 240.255.255.254
}
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
func NewManager() *Manager {
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
return &Manager{
nextIP: baseIP,
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: baseIP,
maxIP: maxIP,
}
}
// AllocateFakeIP allocates a fake IP for the given real IP
// Returns the fake IP, or existing fake IP if already allocated
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
if !realIP.Is4() {
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
}
m.mu.Lock()
defer m.mu.Unlock()
if fakeIP, exists := m.allocated[realIP]; exists {
return fakeIP, nil
}
startIP := m.nextIP
for {
currentIP := m.nextIP
// Advance to next IP, wrapping at boundary
if m.nextIP.Compare(m.maxIP) >= 0 {
m.nextIP = m.baseIP
} else {
m.nextIP = m.nextIP.Next()
}
// Check if current IP is available
if _, inUse := m.fakeToReal[currentIP]; !inUse {
m.allocated[realIP] = currentIP
m.fakeToReal[currentIP] = realIP
return currentIP, nil
}
// Prevent infinite loop if all IPs exhausted
if m.nextIP.Compare(startIP) == 0 {
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
}
}
}
// GetFakeIP returns the fake IP for a real IP if it exists
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
m.mu.Lock()
defer m.mu.Unlock()
fakeIP, exists := m.allocated[realIP]
return fakeIP, exists
}
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
m.mu.Lock()
defer m.mu.Unlock()
realIP, exists := m.fakeToReal[fakeIP]
return realIP, exists
}
// GetFakeIPBlock returns the fake IP block used by this manager
func (m *Manager) GetFakeIPBlock() netip.Prefix {
return netip.MustParsePrefix("240.0.0.0/8")
}

View File

@@ -1,240 +0,0 @@
package fakeip
import (
"net/netip"
"sync"
"testing"
)
func TestNewManager(t *testing.T) {
manager := NewManager()
if manager.baseIP.String() != "240.0.0.1" {
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
}
if manager.maxIP.String() != "240.255.255.254" {
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
}
if manager.nextIP.Compare(manager.baseIP) != 0 {
t.Errorf("Expected nextIP to start at baseIP")
}
}
func TestAllocateFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("8.8.8.8")
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
if !fakeIP.Is4() {
t.Error("Fake IP should be IPv4")
}
// Check it's in the correct range
if fakeIP.As4()[0] != 240 {
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
}
// Should return same fake IP for same real IP
fakeIP2, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to get existing fake IP: %v", err)
}
if fakeIP.Compare(fakeIP2) != 0 {
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
}
}
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
manager := NewManager()
realIPv6 := netip.MustParseAddr("2001:db8::1")
_, err := manager.AllocateFakeIP(realIPv6)
if err == nil {
t.Error("Expected error for IPv6 address")
}
}
func TestGetFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("1.1.1.1")
// Should not exist initially
_, exists := manager.GetFakeIP(realIP)
if exists {
t.Error("Fake IP should not exist before allocation")
}
// Allocate and check
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate: %v", err)
}
fakeIP, exists := manager.GetFakeIP(realIP)
if !exists {
t.Error("Fake IP should exist after allocation")
}
if fakeIP.Compare(expectedFakeIP) != 0 {
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
}
}
func TestMultipleAllocations(t *testing.T) {
manager := NewManager()
allocations := make(map[netip.Addr]netip.Addr)
// Allocate multiple IPs
for i := 1; i <= 100; i++ {
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
}
// Check for duplicates
for _, existingFake := range allocations {
if fakeIP.Compare(existingFake) == 0 {
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
}
}
allocations[realIP] = fakeIP
}
// Verify all allocations can be retrieved
for realIP, expectedFake := range allocations {
actualFake, exists := manager.GetFakeIP(realIP)
if !exists {
t.Errorf("Missing allocation for %s", realIP.String())
}
if actualFake.Compare(expectedFake) != 0 {
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
}
}
}
func TestGetFakeIPBlock(t *testing.T) {
manager := NewManager()
block := manager.GetFakeIPBlock()
expected := "240.0.0.0/8"
if block.String() != expected {
t.Errorf("Expected %s, got %s", expected, block.String())
}
}
func TestConcurrentAccess(t *testing.T) {
manager := NewManager()
const numGoroutines = 50
const allocationsPerGoroutine = 10
var wg sync.WaitGroup
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
// Concurrent allocations
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for j := 0; j < allocationsPerGoroutine; j++ {
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
return
}
results <- fakeIP
}
}(i)
}
wg.Wait()
close(results)
// Check for duplicates
seen := make(map[netip.Addr]bool)
count := 0
for fakeIP := range results {
if seen[fakeIP] {
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
}
seen[fakeIP] = true
count++
}
if count != numGoroutines*allocationsPerGoroutine {
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
}
}
func TestIPExhaustion(t *testing.T) {
// Create a manager with limited range for testing
manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
}
// Allocate all available IPs
realIPs := []netip.Addr{
netip.MustParseAddr("1.0.0.1"),
netip.MustParseAddr("1.0.0.2"),
netip.MustParseAddr("1.0.0.3"),
}
for _, realIP := range realIPs {
_, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
}
// Try to allocate one more - should fail
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
if err == nil {
t.Error("Expected exhaustion error")
}
}
func TestWrapAround(t *testing.T) {
// Create manager starting near the end of range
manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
}
// Allocate the last IP
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
if err != nil {
t.Fatalf("Failed to allocate first IP: %v", err)
}
if fakeIP1.String() != "240.0.0.254" {
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
}
// Next allocation should wrap around to the beginning
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
if err != nil {
t.Fatalf("Failed to allocate second IP: %v", err)
}
if fakeIP2.String() != "240.0.0.1" {
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
}
}

View File

@@ -8,11 +8,9 @@ import (
"net/netip"
"net/url"
"runtime"
"slices"
"sync"
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -26,8 +24,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@@ -53,7 +49,7 @@ type Manager interface {
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
SetFirewall(firewall.Manager) error
EnableServerRouter(firewall firewall.Manager) error
Stop(stateManager *statemanager.Manager)
}
@@ -67,7 +63,6 @@ type ManagerConfig struct {
InitialRoutes []*route.Route
StateManager *statemanager.Manager
DNSServer dns.Server
DNSFeatureFlag bool
PeerStore *peerstore.Store
DisableClientRoutes bool
DisableServerRoutes bool
@@ -94,13 +89,11 @@ type DefaultManager struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server
firewall firewall.Manager
peerStore *peerstore.Store
useNewDNSRoute bool
disableClientRoutes bool
disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager
}
func NewManager(config ManagerConfig) *DefaultManager {
@@ -136,31 +129,11 @@ func NewManager(config ManagerConfig) *DefaultManager {
}
if runtime.GOOS == "android" {
dm.setupAndroidRoutes(config)
cr := dm.initialClientRoutes(config.InitialRoutes)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
cr := m.initialClientRoutes(config.InitialRoutes)
routesForComparison := slices.Clone(cr)
if config.DNSFeatureFlag {
m.fakeIPManager = fakeip.NewManager()
id := uuid.NewString()
fakeIPRoute := &route.Route{
ID: route.ID(id),
Network: m.fakeIPManager.GetFakeIPBlock(),
NetID: route.NetID(id),
Peer: m.pubKey,
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.routeRefCounter = refcounter.New(
@@ -249,16 +222,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
return routeselector.NewRouteSelector()
}
// SetFirewall sets the firewall manager for the DefaultManager
// Not thread-safe, should be called before starting the manager
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
m.firewall = firewall
if m.disableServerRoutes || firewall == nil {
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
if m.disableServerRoutes {
log.Info("server routes are disabled")
return nil
}
if firewall == nil {
return errors.New("firewall manager is not set")
}
var err error
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil {
@@ -326,20 +299,17 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
for id, route := range toAdd {
params := common.HandlerParams{
Route: route,
RouteRefCounter: m.routeRefCounter,
AllowedIPsRefCounter: m.allowedIPsRefCounter,
DnsRouterInterval: m.dnsRouteInterval,
StatusRecorder: m.statusRecorder,
WgInterface: m.wgInterface,
DnsServer: m.dnsServer,
PeerStore: m.peerStore,
UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall,
FakeIPManager: m.fakeIPManager,
}
handler := client.HandlerFromRoute(params)
handler := client.HandlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
@@ -547,7 +517,6 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}

View File

@@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager
type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
@@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
}
func (m *MockManager) SetFirewall(firewall.Manager) error {
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
panic("implement me")
}

View File

@@ -0,0 +1,136 @@
package notifier
import (
"net/netip"
"runtime"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
initialRouteRanges []string
routeRanges []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() {
continue
}
nets = append(nets, r.Network.String())
}
sort.Strings(nets)
n.initialRouteRanges = nets
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
newNets = append(newNets, r.Network.String())
}
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
}(n.listener)
}
func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
for i, v := range a {
if v != b[i] {
return true
}
}
return false
}
func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
func addIPv6RangeIfNeeded(inputRanges []string) []string {
ranges := inputRanges
for _, r := range inputRanges {
// we are intentionally adding the ipv6 default range in case of ipv4 default range
// to ensure that all traffic is managed by the tunnel interface on android
if r == "0.0.0.0/0" {
ranges = append(ranges, "::/0")
break
}
}
return ranges
}

View File

@@ -1,127 +0,0 @@
//go:build android
package notifier
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
// initialRoutes contains fake IP block for interface configuration
filteredInitial := make([]*route.Route, 0)
for _, r := range initialRoutes {
if r.IsDynamic() {
continue
}
filteredInitial = append(filteredInitial, r)
}
n.initialRoutes = filteredInitial
// routesForComparison excludes fake IP block for comparison with new routes
filteredComparison := make([]*route.Route, 0)
for _, r := range routesForComparison {
if r.IsDynamic() {
continue
}
filteredComparison = append(filteredComparison, r)
}
n.currentRoutes = filteredComparison
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
for _, r := range routes {
if r.IsDynamic() {
continue
}
newRoutes = append(newRoutes, r)
}
}
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
return
}
n.currentRoutes = newRoutes
n.notify()
}
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
// Not used on Android
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
routeStrings := n.routesToStrings(n.currentRoutes)
sort.Strings(routeStrings)
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
}(n.listener)
}
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
nets := make([]string, 0, len(routes))
for _, r := range routes {
nets = append(nets, r.NetString())
}
return nets
}
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
slices.SortFunc(a, func(x, y *route.Route) int {
return strings.Compare(x.NetString(), y.NetString())
})
slices.SortFunc(b, func(x, y *route.Route) int {
return strings.Compare(x.NetString(), y.NetString())
})
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
return x.NetString() == y.NetString()
})
}
func (n *Notifier) GetInitialRouteRanges() []string {
initialStrings := n.routesToStrings(n.initialRoutes)
sort.Strings(initialStrings)
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
}
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
for _, r := range routes {
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
return append(slices.Clone(inputRanges), "::/0")
}
}
return inputRanges
}

View File

@@ -1,80 +0,0 @@
//go:build ios
package notifier
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
currentPrefixes []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
if slices.Equal(n.currentPrefixes, newNets) {
return
}
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
}(n.listener)
}
func (n *Notifier) GetInitialRouteRanges() []string {
return nil
}
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
for _, r := range inputRanges {
if r == "0.0.0.0/0" {
return append(slices.Clone(inputRanges), "::/0")
}
}
return inputRanges
}

View File

@@ -1,36 +0,0 @@
//go:build !android && !ios
package notifier
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct{}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
// Not used on non-mobile platforms
}
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
@@ -17,11 +16,11 @@ type Route struct {
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(params common.HandlerParams) *Route {
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/netip"
"sync"
"sync/atomic"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -29,10 +28,7 @@ func (n Nexthop) String() string {
if n.Intf == nil {
return n.IP.String()
}
if n.IP.IsValid() {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
type wgIface interface {
@@ -53,9 +49,6 @@ type SysOps struct {
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
// seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems
seq atomic.Uint32
}
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
@@ -65,11 +58,6 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
}
}
//nolint:unused // only used on BSD systems
func (r *SysOps) getSeq() int {
return int(r.seq.Add(1))
}
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()

View File

@@ -108,7 +108,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1

View File

@@ -10,11 +10,10 @@ type StatusType string
const (
StatusIdle StatusType = "Idle"
StatusConnecting StatusType = "Connecting"
StatusConnected StatusType = "Connected"
StatusNeedsLogin StatusType = "NeedsLogin"
StatusLoginFailed StatusType = "LoginFailed"
StatusSessionExpired StatusType = "SessionExpired"
StatusConnecting StatusType = "Connecting"
StatusConnected StatusType = "Connected"
StatusNeedsLogin StatusType = "NeedsLogin"
StatusLoginFailed StatusType = "LoginFailed"
)
// CtxInitState setup context state into the context tree.

Some files were not shown because too many files have changed in this diff Show More