mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-12 12:36:15 -04:00
Compare commits
1 Commits
bind-ipv6
...
preresolve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3251bc79fa |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -134,7 +134,6 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -181,7 +180,6 @@ jobs:
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -149,7 +149,6 @@ nfpms:
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
@@ -165,7 +164,6 @@ dockers:
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
|
||||
@@ -14,9 +14,6 @@
|
||||
<br>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
@@ -32,13 +29,13 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||
New: NetBird Kubernetes Operator
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
FROM alpine:3.21.3
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
FROM alpine:3.21.0
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
&& adduser -D -h /var/lib/netbird netbird
|
||||
|
||||
@@ -59,8 +59,6 @@ type Client struct {
|
||||
deviceName string
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -108,8 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -134,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -176,55 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
func (c *Client) Networks() *NetworkArray {
|
||||
if c.connectClient == nil {
|
||||
log.Error("not connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
log.Error("could not get engine")
|
||||
return nil
|
||||
}
|
||||
|
||||
routeManager := engine.GetRouteManager()
|
||||
if routeManager == nil {
|
||||
log.Error("could not get route manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
netStr := r.Network.String()
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
return networkArray
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
}
|
||||
|
||||
type NetworkArray struct {
|
||||
items []Network
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Add(s Network) *NetworkArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Get(i int) *Network {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
@@ -7,23 +7,30 @@ type PeerInfo struct {
|
||||
ConnStatus string // Todo replace to enum
|
||||
}
|
||||
|
||||
// PeerInfoArray is a wrapper of []PeerInfo
|
||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||
type PeerInfoCollection interface {
|
||||
Add(s string) PeerInfoCollection
|
||||
Get(i int) string
|
||||
Size() int
|
||||
}
|
||||
|
||||
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||
type PeerInfoArray struct {
|
||||
items []PeerInfo
|
||||
}
|
||||
|
||||
// Add new PeerInfo to the collection
|
||||
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
|
||||
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array *PeerInfoArray) Get(i int) *PeerInfo {
|
||||
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array *PeerInfoArray) Size() int {
|
||||
func (array PeerInfoArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// Preferences exports a subset of the internal config for gomobile
|
||||
// Preferences export a subset of the internal config for gomobile
|
||||
type Preferences struct {
|
||||
configInput internal.ConfigInput
|
||||
}
|
||||
|
||||
// NewPreferences creates a new Preferences instance
|
||||
// NewPreferences create new Preferences instance
|
||||
func NewPreferences(configPath string) *Preferences {
|
||||
ci := internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
||||
return &Preferences{ci}
|
||||
}
|
||||
|
||||
// GetManagementURL reads URL from config file
|
||||
// GetManagementURL read url from config file
|
||||
func (p *Preferences) GetManagementURL() (string, error) {
|
||||
if p.configInput.ManagementURL != "" {
|
||||
return p.configInput.ManagementURL, nil
|
||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
||||
return cfg.ManagementURL.String(), err
|
||||
}
|
||||
|
||||
// SetManagementURL stores the given URL and waits for commit
|
||||
// SetManagementURL store the given url and wait for commit
|
||||
func (p *Preferences) SetManagementURL(url string) {
|
||||
p.configInput.ManagementURL = url
|
||||
}
|
||||
|
||||
// GetAdminURL reads URL from config file
|
||||
// GetAdminURL read url from config file
|
||||
func (p *Preferences) GetAdminURL() (string, error) {
|
||||
if p.configInput.AdminURL != "" {
|
||||
return p.configInput.AdminURL, nil
|
||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
||||
return cfg.AdminURL.String(), err
|
||||
}
|
||||
|
||||
// SetAdminURL stores the given URL and waits for commit
|
||||
// SetAdminURL store the given url and wait for commit
|
||||
func (p *Preferences) SetAdminURL(url string) {
|
||||
p.configInput.AdminURL = url
|
||||
}
|
||||
|
||||
// GetPreSharedKey reads pre-shared key from config file
|
||||
// GetPreSharedKey read preshared key from config file
|
||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
if p.configInput.PreSharedKey != nil {
|
||||
return *p.configInput.PreSharedKey, nil
|
||||
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
return cfg.PreSharedKey, err
|
||||
}
|
||||
|
||||
// SetPreSharedKey stores the given key and waits for commit
|
||||
// SetPreSharedKey store the given key and wait for commit
|
||||
func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||
// SetRosenpassEnabled store if rosenpass is enabled
|
||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||
p.configInput.RosenpassEnabled = &enabled
|
||||
}
|
||||
|
||||
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||
// GetRosenpassEnabled read rosenpass enabled from config file
|
||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
if p.configInput.RosenpassEnabled != nil {
|
||||
return *p.configInput.RosenpassEnabled, nil
|
||||
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
return cfg.RosenpassEnabled, err
|
||||
}
|
||||
|
||||
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||
// SetRosenpassPermissive store the given permissive and wait for commit
|
||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||
p.configInput.RosenpassPermissive = &permissive
|
||||
}
|
||||
|
||||
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||
// GetRosenpassPermissive read rosenpass permissive from config file
|
||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
if p.configInput.RosenpassPermissive != nil {
|
||||
return *p.configInput.RosenpassPermissive, nil
|
||||
@@ -107,119 +107,7 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
return cfg.RosenpassPermissive, err
|
||||
}
|
||||
|
||||
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||
if p.configInput.DisableClientRoutes != nil {
|
||||
return *p.configInput.DisableClientRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableClientRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableClientRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||
p.configInput.DisableClientRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||
if p.configInput.DisableServerRoutes != nil {
|
||||
return *p.configInput.DisableServerRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableServerRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableServerRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||
p.configInput.DisableServerRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableDNS reads disable DNS setting from config file
|
||||
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||
if p.configInput.DisableDNS != nil {
|
||||
return *p.configInput.DisableDNS, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableDNS, err
|
||||
}
|
||||
|
||||
// SetDisableDNS stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||
p.configInput.DisableDNS = &disable
|
||||
}
|
||||
|
||||
// GetDisableFirewall reads disable firewall setting from config file
|
||||
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||
if p.configInput.DisableFirewall != nil {
|
||||
return *p.configInput.DisableFirewall, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableFirewall, err
|
||||
}
|
||||
|
||||
// SetDisableFirewall stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||
p.configInput.DisableFirewall = &disable
|
||||
}
|
||||
|
||||
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||
if p.configInput.ServerSSHAllowed != nil {
|
||||
return *p.configInput.ServerSSHAllowed, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.ServerSSHAllowed == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.ServerSSHAllowed, err
|
||||
}
|
||||
|
||||
// SetServerSSHAllowed stores the given value and waits for commit
|
||||
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
return *p.configInput.BlockInbound, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.BlockInbound, err
|
||||
}
|
||||
|
||||
// SetBlockInbound stores the given value and waits for commit
|
||||
func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
// Commit write out the changes into config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
|
||||
@@ -69,10 +69,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
status := resp.GetStatus()
|
||||
|
||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired) {
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
cmd.Printf("Daemon status: %s\n\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
@@ -120,7 +117,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -38,5 +38,5 @@ func init() {
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
|
||||
for _, stage := range resp.Stages {
|
||||
if stage.ForwardingDetails != nil {
|
||||
|
||||
@@ -62,5 +62,5 @@ type ConnKey struct {
|
||||
}
|
||||
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package conntrack
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -20,10 +19,6 @@ const (
|
||||
DefaultICMPTimeout = 30 * time.Second
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||
MaxICMPPayloadLength = 28
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
@@ -34,7 +29,7 @@ type ICMPConnKey struct {
|
||||
}
|
||||
|
||||
func (i ICMPConnKey) String() string {
|
||||
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
@@ -55,72 +50,6 @@ type ICMPTracker struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
||||
type ICMPInfo struct {
|
||||
TypeCode layers.ICMPv4TypeCode
|
||||
PayloadData [MaxICMPPayloadLength]byte
|
||||
// actual length of valid data
|
||||
PayloadLen int
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||
func (info ICMPInfo) String() string {
|
||||
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return info.TypeCode.String()
|
||||
}
|
||||
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||
func (info ICMPInfo) isErrorMessage() bool {
|
||||
typ := info.TypeCode.Type()
|
||||
return typ == 3 || // Destination Unreachable
|
||||
typ == 5 || // Redirect
|
||||
typ == 11 || // Time Exceeded
|
||||
typ == 12 // Parameter Problem
|
||||
}
|
||||
|
||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||
func (info ICMPInfo) parseOriginalPacket() string {
|
||||
if info.PayloadLen < MaxICMPPayloadLength {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: handle IPv6
|
||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
protocol := info.PayloadData[9]
|
||||
srcIP := net.IP(info.PayloadData[12:16])
|
||||
dstIP := net.IP(info.PayloadData[16:20])
|
||||
|
||||
transportData := info.PayloadData[20:]
|
||||
|
||||
switch nftypes.Protocol(protocol) {
|
||||
case nftypes.TCP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.UDP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.ICMP:
|
||||
icmpType := transportData[0]
|
||||
icmpCode := transportData[1]
|
||||
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
||||
}
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
@@ -164,64 +93,30 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
ruleId []byte,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
direction nftypes.Direction,
|
||||
ruleId []byte,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
typ, code := typecode.Type(), typecode.Code()
|
||||
icmpInfo := ICMPInfo{
|
||||
TypeCode: typecode,
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
icmpInfo.PayloadLen = len(payload)
|
||||
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
||||
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
||||
}
|
||||
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
||||
}
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
@@ -243,7 +138,7 @@ func (t *ICMPTracker) track(
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
}
|
||||
|
||||
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
wg.Wait()
|
||||
|
||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
||||
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
|
||||
}
|
||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
||||
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
|
||||
@@ -1,408 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var sum1, sum2 uint32
|
||||
|
||||
// Parallel processing - unroll and compute two sums simultaneously
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||
// Skip checksum field at [10:12]
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
||||
|
||||
sum := sum1 + sum2
|
||||
|
||||
// Handle remaining bytes for headers > 20 bytes
|
||||
for i := 20; i < len(header)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
|
||||
if len(header)%2 == 1 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
|
||||
// Optimized carry fold - single iteration handles most cases
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
|
||||
// Process 16 bytes at once with 4 parallel accumulators
|
||||
for i <= len(data)-16 {
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
||||
i += 16
|
||||
}
|
||||
|
||||
sum := sum1 + sum2 + sum3 + sum4
|
||||
|
||||
// Handle remaining bytes
|
||||
for i < len(data)-1 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
i += 2
|
||||
}
|
||||
|
||||
if len(data)%2 == 1 {
|
||||
sum += uint32(data[len(data)-1]) << 8
|
||||
}
|
||||
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
reverse: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
delete(b.reverse, translated)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||
}
|
||||
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
}
|
||||
|
||||
m.dnatMappings[originalAddr] = translatedAddr
|
||||
m.dnatBiMap.set(originalAddr, translatedAddr)
|
||||
|
||||
if len(m.dnatMappings) == 1 {
|
||||
m.dnatEnabled.Store(true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||
}
|
||||
|
||||
delete(m.dnatMappings, originalAddr)
|
||||
m.dnatBiMap.delete(originalAddr)
|
||||
if len(m.dnatMappings) == 0 {
|
||||
m.dnatEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
translated, exists := m.dnatBiMap.getTranslated(addr)
|
||||
m.dnatMutex.RUnlock()
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
||||
m.dnatMutex.RUnlock()
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
if oldChecksum == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
icmpData := packetData[icmpStart:]
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||
checksum := icmpChecksum(icmpData)
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||
}
|
||||
if len(oldBytes)%2 == 1 {
|
||||
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||
}
|
||||
|
||||
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||
}
|
||||
if len(newBytes)%2 == 1 {
|
||||
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||
}
|
||||
}
|
||||
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// BenchmarkDNATTranslation measures the performance of DNAT operations
|
||||
func BenchmarkDNATTranslation(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_with_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
description: "TCP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "tcp_without_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
description: "TCP packet without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_with_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
description: "UDP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "udp_without_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
description: "UDP packet without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "icmp_with_dnat",
|
||||
proto: layers.IPProtocolICMPv4,
|
||||
setupDNAT: true,
|
||||
description: "ICMP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "icmp_without_dnat",
|
||||
proto: layers.IPProtocolICMPv4,
|
||||
setupDNAT: false,
|
||||
description: "ICMP packet without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup DNAT mapping if needed
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Create test packets
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||
|
||||
// Pre-establish connection for reverse DNAT test
|
||||
if sc.setupDNAT {
|
||||
manager.filterOutbound(outboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
// Benchmark outbound DNAT translation
|
||||
b.Run("outbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time since translation modifies it
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark inbound reverse DNAT translation
|
||||
if sc.setupDNAT {
|
||||
b.Run("inbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time since translation modifies it
|
||||
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup multiple DNAT mappings
|
||||
numMappings := 100
|
||||
originalIPs := make([]netip.Addr, numMappings)
|
||||
translatedIPs := make([]netip.Addr, numMappings)
|
||||
|
||||
for i := 0; i < numMappings; i++ {
|
||||
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
// Pre-generate packets
|
||||
outboundPackets := make([][]byte, numMappings)
|
||||
inboundPackets := make([][]byte, numMappings)
|
||||
for i := 0; i < numMappings; i++ {
|
||||
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
|
||||
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||
// Establish connections
|
||||
manager.filterOutbound(outboundPackets[i], 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
b.Run("concurrent_outbound", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
idx := i % numMappings
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("concurrent_inbound", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
idx := i % numMappings
|
||||
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||
manager.filterInbound(packet, 0)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
|
||||
func BenchmarkDNATScaling(b *testing.B) {
|
||||
mappingCounts := []int{1, 10, 100, 1000}
|
||||
|
||||
for _, count := range mappingCounts {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup DNAT mappings
|
||||
for i := 0; i < count; i++ {
|
||||
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Test with the last mapping added (worst case for lookup)
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateDNATTestPacket creates a test packet for DNAT benchmarking
|
||||
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP.AsSlice(),
|
||||
DstIP: dstIP.AsSlice(),
|
||||
Protocol: proto,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = udp
|
||||
case layers.IPProtocolICMPv4:
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||
}
|
||||
transportLayer = icmp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||
require.NoError(tb, err)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
|
||||
func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
// Create test data for checksum calculations
|
||||
testData := make([]byte, 64) // Typical packet size for checksum testing
|
||||
for i := range testData {
|
||||
testData[i] = byte(i)
|
||||
}
|
||||
|
||||
b.Run("ipv4_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("icmp_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = icmpChecksum(testData)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("incremental_update", func(b *testing.B) {
|
||||
oldBytes := []byte{192, 168, 1, 100}
|
||||
newBytes := []byte{10, 0, 0, 100}
|
||||
oldChecksum := uint16(0x1234)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time to isolate allocation testing
|
||||
testPacket := make([]byte, len(packet))
|
||||
copy(testPacket, packet)
|
||||
|
||||
// Parse the packet fresh each time to get a clean decoder
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
manager.translateOutboundDNAT(testPacket, d)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
|
||||
func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||
// Create a test packet
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
|
||||
|
||||
b.Run("direct_byte_access", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Direct extraction from packet bytes
|
||||
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decoder_extraction", func(b *testing.B) {
|
||||
// Create decoder once for comparison
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Extract using decoder (traditional method)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
_ = dst
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
|
||||
func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
// Create test IPv4 header (20 bytes)
|
||||
header := make([]byte, 20)
|
||||
for i := range header {
|
||||
header[i] = byte(i)
|
||||
}
|
||||
// Clear checksum field
|
||||
header[10] = 0
|
||||
header[11] = 0
|
||||
|
||||
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ipv4Checksum(header)
|
||||
}
|
||||
})
|
||||
|
||||
// Test incremental checksum updates
|
||||
oldIP := []byte{192, 168, 1, 100}
|
||||
newIP := []byte{10, 0, 0, 100}
|
||||
oldChecksum := uint16(0x1234)
|
||||
|
||||
b.Run("optimized_incremental_update", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
// Add DNAT mapping
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"TCP", layers.IPProtocolTCP, 12345, 80},
|
||||
{"UDP", layers.IPProtocolUDP, 12345, 53},
|
||||
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test outbound DNAT translation
|
||||
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
originalOutbound := make([]byte, len(outboundPacket))
|
||||
copy(originalOutbound, outboundPacket)
|
||||
|
||||
// Process outbound packet (should translate destination)
|
||||
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, translated, "Outbound packet should be translated")
|
||||
|
||||
// Verify destination IP was changed
|
||||
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
|
||||
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
|
||||
|
||||
// Test inbound reverse DNAT translation
|
||||
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
|
||||
originalInbound := make([]byte, len(inboundPacket))
|
||||
copy(originalInbound, inboundPacket)
|
||||
|
||||
// Process inbound packet (should reverse translate source)
|
||||
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, reversed, "Inbound packet should be reverse translated")
|
||||
|
||||
// Verify source IP was changed back to original
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
|
||||
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
|
||||
|
||||
// Test that checksums are recalculated correctly
|
||||
if tc.protocol != layers.IPProtocolICMPv4 {
|
||||
// For TCP/UDP, verify the transport checksum was updated
|
||||
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
|
||||
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// parsePacket helper to create a decoder for testing
|
||||
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
t.Helper()
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
// Test adding mapping
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify mapping exists
|
||||
result, exists := manager.getDNATTranslation(originalIP)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, translatedIP, result)
|
||||
|
||||
// Test reverse lookup
|
||||
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, originalIP, reverseResult)
|
||||
|
||||
// Test removing mapping
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify mapping no longer exists
|
||||
_, exists = manager.getDNATTranslation(originalIP)
|
||||
require.False(t, exists)
|
||||
|
||||
_, exists = manager.findReverseDNATMapping(translatedIP)
|
||||
require.False(t, exists)
|
||||
|
||||
// Test error cases
|
||||
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
|
||||
require.Error(t, err, "Should reject invalid original IP")
|
||||
|
||||
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
|
||||
require.Error(t, err, "Should reject invalid translated IP")
|
||||
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
@@ -104,12 +104,6 @@ type Manager struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
|
||||
// Internal 1:1 DNAT
|
||||
dnatEnabled atomic.Bool
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -195,7 +189,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -526,6 +519,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
@@ -572,14 +581,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FilterOutBound filters outgoing packets
|
||||
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
||||
return m.filterOutbound(packetData, size)
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
}
|
||||
|
||||
// FilterInbound filters incoming packets
|
||||
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
||||
return m.filterInbound(packetData, size)
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||
return m.dropFilter(packetData, size)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
@@ -587,7 +596,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -609,8 +618,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// for netflow we keep track even if the firewall is stateless
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -662,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -675,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -714,9 +723,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
||||
return false
|
||||
}
|
||||
|
||||
// filterInbound implements filtering logic for incoming packets.
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -738,15 +747,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
return false
|
||||
}
|
||||
@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(inbound, 0)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.filterOutbound(testOut, 0)
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(testIn, 0)
|
||||
manager.dropFilter(testIn, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(inbound, 0)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.filterOutbound(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.filterOutbound(syn, 0)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.filterInbound(synack, 0)
|
||||
manager.dropFilter(synack, 0)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.filterOutbound(ack, 0)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(inbound, 0)
|
||||
manager.dropFilter(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.filterOutbound(syn, 0)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.filterInbound(synack, 0)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.filterOutbound(ack, 0)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.filterOutbound(outPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.filterInbound(inPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.filterOutbound(p.syn, 0)
|
||||
manager.filterInbound(p.synAck, 0)
|
||||
manager.filterOutbound(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
// Data transfer
|
||||
manager.filterOutbound(p.request, 0)
|
||||
manager.filterInbound(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
// Connection teardown
|
||||
manager.filterOutbound(p.finClient, 0)
|
||||
manager.filterInbound(p.ackServer, 0)
|
||||
manager.filterInbound(p.finServer, 0)
|
||||
manager.filterOutbound(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.filterOutbound(syn, 0)
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.filterInbound(synack, 0)
|
||||
manager.dropFilter(synack, 0)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.filterOutbound(ack, 0)
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.filterOutbound(outPackets[connIdx], 0)
|
||||
manager.filterInbound(inPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.filterOutbound(p.syn, 0)
|
||||
manager.filterInbound(p.synAck, 0)
|
||||
manager.filterOutbound(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
|
||||
manager.filterOutbound(p.request, 0)
|
||||
manager.filterInbound(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
|
||||
manager.filterOutbound(p.finClient, 0)
|
||||
manager.filterInbound(p.ackServer, 0)
|
||||
manager.filterInbound(p.finServer, 0)
|
||||
manager.filterOutbound(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.filterInbound(buf.Bytes(), 0) {
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.filterOutbound(buf.Bytes(), 0)
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.filterOutbound(buf.Bytes(), 0)
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
const (
|
||||
saveFrequency = int64(5 * time.Second)
|
||||
)
|
||||
|
||||
type PeerRecord struct {
|
||||
Address netip.AddrPort
|
||||
LastActivity atomic.Int64 // UnixNano timestamp
|
||||
}
|
||||
|
||||
type ActivityRecorder struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*PeerRecord // publicKey to PeerRecord map
|
||||
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
|
||||
}
|
||||
|
||||
func NewActivityRecorder() *ActivityRecorder {
|
||||
return &ActivityRecorder{
|
||||
peers: make(map[string]*PeerRecord),
|
||||
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastActivities returns a snapshot of peer last activity
|
||||
func (r *ActivityRecorder) GetLastActivities() map[string]time.Time {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
activities := make(map[string]time.Time, len(r.peers))
|
||||
for key, record := range r.peers {
|
||||
unixNano := record.LastActivity.Load()
|
||||
activities[key] = time.Unix(0, unixNano)
|
||||
}
|
||||
return activities
|
||||
}
|
||||
|
||||
// UpsertAddress adds or updates the address for a publicKey
|
||||
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if pr, exists := r.peers[publicKey]; exists {
|
||||
delete(r.addrToPeer, pr.Address)
|
||||
pr.Address = address
|
||||
} else {
|
||||
record := &PeerRecord{
|
||||
Address: address,
|
||||
}
|
||||
record.LastActivity.Store(monotime.Now())
|
||||
r.peers[publicKey] = record
|
||||
}
|
||||
|
||||
r.addrToPeer[address] = r.peers[publicKey]
|
||||
}
|
||||
|
||||
func (r *ActivityRecorder) Remove(publicKey string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if record, exists := r.peers[publicKey]; exists {
|
||||
delete(r.addrToPeer, record.Address)
|
||||
delete(r.peers, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
// record updates LastActivity for the given address using atomic store
|
||||
func (r *ActivityRecorder) record(address netip.AddrPort) {
|
||||
r.mu.RLock()
|
||||
record, ok := r.addrToPeer[address]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
log.Warnf("could not find record for address %s", address)
|
||||
return
|
||||
}
|
||||
|
||||
now := monotime.Now()
|
||||
last := record.LastActivity.Load()
|
||||
if now-last < saveFrequency {
|
||||
return
|
||||
}
|
||||
|
||||
_ = record.LastActivity.CompareAndSwap(last, now)
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
||||
peer := "peer1"
|
||||
ar := NewActivityRecorder()
|
||||
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
|
||||
activities := ar.GetLastActivities()
|
||||
|
||||
p, ok := activities[peer]
|
||||
if !ok {
|
||||
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
||||
}
|
||||
|
||||
if p.IsZero() {
|
||||
t.Fatalf("Expected activity for peer %s, but got zero", peer)
|
||||
}
|
||||
|
||||
if p.Before(time.Now().Add(-2 * time.Minute)) {
|
||||
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
||||
}
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoIPv4Conn = errors.New("no IPv4 connection available")
|
||||
errNoIPv6Conn = errors.New("no IPv6 connection available")
|
||||
errInvalidAddr = errors.New("invalid address type")
|
||||
)
|
||||
|
||||
// DualStackPacketConn is a composite PacketConn that can handle both IPv4 and IPv6
|
||||
type DualStackPacketConn struct {
|
||||
ipv4Conn net.PacketConn
|
||||
ipv6Conn net.PacketConn
|
||||
}
|
||||
|
||||
// NewDualStackPacketConn creates a new dual-stack packet connection
|
||||
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
|
||||
return &DualStackPacketConn{
|
||||
ipv4Conn: ipv4Conn,
|
||||
ipv6Conn: ipv6Conn,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads from both IPv4 and IPv6 connections
|
||||
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
// Prefer IPv4 if available
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.ReadFrom(b)
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.ReadFrom(b)
|
||||
}
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
|
||||
// WriteTo writes to the appropriate connection based on the address type
|
||||
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp",
|
||||
Addr: addr,
|
||||
Err: errInvalidAddr,
|
||||
}
|
||||
}
|
||||
|
||||
if udpAddr.IP.To4() == nil {
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.WriteTo(b, addr)
|
||||
}
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp6",
|
||||
Addr: addr,
|
||||
Err: errNoIPv6Conn,
|
||||
}
|
||||
}
|
||||
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.WriteTo(b, addr)
|
||||
}
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp4",
|
||||
Addr: addr,
|
||||
Err: errNoIPv4Conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes both connections
|
||||
func (d *DualStackPacketConn) Close() error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the IPv4 connection (for compatibility)
|
||||
func (d *DualStackPacketConn) LocalAddr() net.Addr {
|
||||
if d.ipv4Conn != nil {
|
||||
return d.ipv4Conn.LocalAddr()
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
return d.ipv6Conn.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline sets the deadline for both connections
|
||||
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline for both connections
|
||||
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline for both connections
|
||||
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
var result *multierror.Error
|
||||
if d.ipv4Conn != nil {
|
||||
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
if d.ipv6Conn != nil {
|
||||
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
@@ -26,8 +26,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -51,26 +51,22 @@ type ICEBind struct {
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
ipv4Conn *net.UDPConn
|
||||
ipv6Conn *net.UDPConn
|
||||
address wgaddr.Address
|
||||
activityRecorder *ActivityRecorder
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -104,19 +100,15 @@ func (s *ICEBind) Close() error {
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||
return s.activityRecorder
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
if s.udpMux != nil {
|
||||
return s.udpMux, nil
|
||||
if s.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
return nil, errors.New("ICEBind has not been initialized yet")
|
||||
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
@@ -148,46 +140,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
log.Errorf("ICEBind: unexpected address type: %T", conn.LocalAddr())
|
||||
return nil
|
||||
}
|
||||
isIPv6 := localAddr.IP.To4() == nil
|
||||
|
||||
if isIPv6 {
|
||||
s.ipv6Conn = conn
|
||||
} else {
|
||||
s.ipv4Conn = conn
|
||||
}
|
||||
|
||||
needsNewMux := s.udpMux == nil && (s.ipv4Conn != nil || s.ipv6Conn != nil)
|
||||
needsUpgrade := s.udpMux != nil && s.ipv4Conn != nil && s.ipv6Conn != nil
|
||||
|
||||
if needsNewMux || needsUpgrade {
|
||||
var iceMuxConn net.PacketConn
|
||||
switch {
|
||||
case s.ipv4Conn != nil && s.ipv6Conn != nil:
|
||||
iceMuxConn = NewDualStackPacketConn(s.ipv4Conn, s.ipv6Conn)
|
||||
case s.ipv4Conn != nil:
|
||||
iceMuxConn = s.ipv4Conn
|
||||
default:
|
||||
iceMuxConn = s.ipv6Conn
|
||||
}
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: iceMuxConn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
},
|
||||
)
|
||||
}
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
@@ -234,17 +198,7 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
udpAddr, ok := msg.Addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
log.Errorf("ICEBind: unexpected address type: %T", msg.Addr)
|
||||
continue
|
||||
}
|
||||
addrPort := udpAddr.AddrPort()
|
||||
|
||||
if isTransportPkg(msg.Buffers, msg.N) {
|
||||
s.activityRecorder.record(addrPort)
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
@@ -265,10 +219,9 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
|
||||
return true, err
|
||||
}
|
||||
|
||||
if s.udpMux != nil {
|
||||
if err := s.udpMux.HandleSTUNMessage(msg, addr); err != nil {
|
||||
log.Warnf("failed to handle STUN packet: %v", err)
|
||||
}
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet")
|
||||
}
|
||||
|
||||
buffers[i] = []byte{}
|
||||
@@ -304,13 +257,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
copy(buffs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||
|
||||
if isTransportPkg(buffs, sizes[0]) {
|
||||
if ep, ok := eps[0].(*Endpoint); ok {
|
||||
c.activityRecorder.record(ep.AddrPort)
|
||||
}
|
||||
}
|
||||
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
@@ -326,19 +272,3 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
func isTransportPkg(buffers [][]byte, n int) bool {
|
||||
// The first buffer should contain at least 4 bytes for type
|
||||
if len(buffers[0]) < 4 {
|
||||
return true
|
||||
}
|
||||
|
||||
// WireGuard packet type is a little-endian uint32 at start
|
||||
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
|
||||
|
||||
// Check if packetType matches known WireGuard message types
|
||||
if packetType == 4 && n > 32 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -342,9 +342,6 @@ func (m *UDPMuxDefault) Close() error {
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
if dualStackConn, ok := m.params.UDPConn.(*DualStackPacketConn); ok {
|
||||
return dualStackConn.WriteTo(buf, rAddr)
|
||||
}
|
||||
return m.params.UDPConn.WriteTo(buf, rAddr)
|
||||
}
|
||||
|
||||
|
||||
@@ -126,11 +126,6 @@ type udpConn struct {
|
||||
}
|
||||
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
// Check if this is a dual-stack connection and handle IPv6 addresses properly
|
||||
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
|
||||
return dualStackConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
if u.filterFn == nil {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
@@ -146,11 +141,6 @@ func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (i
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
|
||||
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
|
||||
return dualStackConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
@@ -158,11 +148,6 @@ func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
|
||||
return dualStackConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
|
||||
@@ -276,7 +276,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -37,18 +36,16 @@ const (
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
|
||||
type WGUSPConfigurer struct {
|
||||
device *device.Device
|
||||
deviceName string
|
||||
activityRecorder *bind.ActivityRecorder
|
||||
device *device.Device
|
||||
deviceName string
|
||||
|
||||
uapiListener net.Listener
|
||||
}
|
||||
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
||||
wgCfg := &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
}
|
||||
wgCfg.startUAPI()
|
||||
return wgCfg
|
||||
@@ -90,19 +87,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
|
||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||
return ipcErr
|
||||
}
|
||||
|
||||
if endpoint != nil {
|
||||
addr, err := netip.ParseAddr(endpoint.IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||
}
|
||||
return nil
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
@@ -119,10 +104,7 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
|
||||
|
||||
c.activityRecorder.Remove(peerKey)
|
||||
return ipcErr
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
@@ -223,10 +205,6 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||
return parseStatus(c.deviceName, ipcStr)
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) LastActivities() map[string]time.Time {
|
||||
return c.activityRecorder.GetLastActivities()
|
||||
}
|
||||
|
||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||
func (t *WGUSPConfigurer) startUAPI() {
|
||||
var err error
|
||||
|
||||
@@ -24,7 +24,6 @@ type WGTunDevice struct {
|
||||
mtu int
|
||||
iceBind *bind.ICEBind
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
name string
|
||||
device *device.Device
|
||||
@@ -33,7 +32,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -41,7 +40,6 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,13 +49,6 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
routesString := routesToString(routes)
|
||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||
|
||||
// Skip DNS configuration when DisableDNS is enabled
|
||||
if t.disableDNS {
|
||||
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||
dns = ""
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
@@ -79,7 +70,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
|
||||
// PacketFilter interface for firewall abilities
|
||||
type PacketFilter interface {
|
||||
// FilterOutbound filter outgoing packets from host to external destinations
|
||||
FilterOutbound(packetData []byte, size int) bool
|
||||
// DropOutgoing filter outgoing packets from host to external destinations
|
||||
DropOutgoing(packetData []byte, size int) bool
|
||||
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
// DropIncoming filter incoming packets from external sources to host
|
||||
DropIncoming(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
dropped++
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
return 1, nil
|
||||
})
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
|
||||
@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
|
||||
@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -19,5 +19,4 @@ type WGConfigurer interface {
|
||||
Close()
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]time.Time
|
||||
}
|
||||
|
||||
@@ -43,7 +43,6 @@ type WGIFaceOpts struct {
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn bind.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
// WGIface represents an interface instance
|
||||
@@ -217,14 +216,6 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||
return w.configurer.GetStats()
|
||||
}
|
||||
|
||||
func (w *WGIface) LastActivities() map[string]time.Time {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.configurer.LastActivities()
|
||||
|
||||
}
|
||||
|
||||
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||
return w.configurer.FullStats()
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
|
||||
@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
// DropIncoming mocks base method.
|
||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
// DropIncoming indicates an expected call of DropIncoming.
|
||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
// DropOutgoing mocks base method.
|
||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
|
||||
@@ -398,15 +398,11 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||
|
||||
if hasPortRestrictions {
|
||||
// Don't squash rules with port restrictions
|
||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||
if drop {
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
|
||||
@@ -330,434 +330,6 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []*mgmProto.FirewallRule
|
||||
expectedCount int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "should not squash rules with port ranges",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with specific ports",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with legacy port field",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with legacy port field should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with DROP action",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with DROP action should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should squash rules without port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||
},
|
||||
{
|
||||
name: "mixed rules should not squash protocol with port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "TCP should not be squashed because one rule has port restrictions",
|
||||
},
|
||||
{
|
||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
// TCP rules with port restrictions - should NOT be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
// UDP rules without port restrictions - SHOULD be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: tt.rules,
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
|
||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||
|
||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||
if tt.expectedCount == 1 {
|
||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
portInfo *mgmProto.PortInfo
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil PortInfo should be empty",
|
||||
portInfo: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero port should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 0,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with valid port should not be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with nil range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: nil,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero start range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 0,
|
||||
End: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero end range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 80,
|
||||
End: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with valid range should not be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := portInfoEmpty(tt.portInfo)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
|
||||
@@ -223,8 +223,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config := &Config{
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
// default to disabling server routes on Android for security
|
||||
DisableServerRoutes: runtime.GOOS == "android",
|
||||
}
|
||||
|
||||
if _, err := config.apply(input); err != nil {
|
||||
@@ -319,6 +317,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
*input.WireguardPort, config.WgPort)
|
||||
config.WgPort = *input.WireguardPort
|
||||
updated = true
|
||||
} else if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||
@@ -414,15 +416,9 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
updated = true
|
||||
} else if config.ServerSSHAllowed == nil {
|
||||
if runtime.GOOS == "android" {
|
||||
// default to disabled SSH on Android for security
|
||||
log.Infof("setting SSH server to false by default on Android")
|
||||
config.ServerSSHAllowed = util.False()
|
||||
} else {
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
}
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
updated = true
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -25,11 +26,11 @@ import (
|
||||
//
|
||||
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||
type ConnMgr struct {
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
enabledLocally bool
|
||||
rosenpassEnabled bool
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
dispatcher *dispatcher.ConnectionDispatcher
|
||||
enabledLocally bool
|
||||
|
||||
lazyConnMgr *manager.Manager
|
||||
|
||||
@@ -38,12 +39,12 @@ type ConnMgr struct {
|
||||
lazyCtxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
|
||||
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||
e := &ConnMgr{
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
rosenpassEnabled: engineConfig.RosenpassEnabled,
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
dispatcher: dispatcher,
|
||||
}
|
||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||
e.enabledLocally = true
|
||||
@@ -63,11 +64,6 @@ func (e *ConnMgr) Start(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
return
|
||||
}
|
||||
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
}
|
||||
@@ -87,12 +83,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warnf("lazy connection manager is enabled by management feature flag")
|
||||
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||
e.initLazyManager(ctx)
|
||||
e.statusRecorder.UpdateLazyConnection(true)
|
||||
return e.addPeersToLazyConnManager()
|
||||
@@ -142,7 +133,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
|
||||
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||
}
|
||||
|
||||
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
|
||||
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
|
||||
for _, peerID := range added {
|
||||
var peerConn *peer.Conn
|
||||
var exists bool
|
||||
@@ -210,7 +201,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer conn.Close(false)
|
||||
defer conn.Close()
|
||||
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
@@ -220,28 +211,23 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||
conn.Log.Infof("removed peer from lazy conn manager")
|
||||
}
|
||||
|
||||
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return conn, true
|
||||
}
|
||||
|
||||
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
|
||||
conn.Log.Infof("activated peer from inactive state")
|
||||
if err := conn.Open(ctx); err != nil {
|
||||
conn.Log.Errorf("failed to open connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
|
||||
// If locally the lazy connection is disabled, we force the peer connection open.
|
||||
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
|
||||
if !e.isStartedWithLazyMgr() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
|
||||
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
|
||||
return conn, true
|
||||
}
|
||||
|
||||
func (e *ConnMgr) Close() {
|
||||
@@ -258,7 +244,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
|
||||
cfg := manager.Config{
|
||||
InactivityThreshold: inactivityThresholdEnv(),
|
||||
}
|
||||
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
|
||||
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
|
||||
|
||||
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
@@ -289,7 +275,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
|
||||
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||
}
|
||||
|
||||
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
|
||||
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
|
||||
}
|
||||
|
||||
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -525,13 +526,17 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
|
||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||
func freePort(initPort int) (int, error) {
|
||||
addr := net.UDPAddr{Port: initPort}
|
||||
addr := net.UDPAddr{}
|
||||
if initPort == 0 {
|
||||
initPort = iface.DefaultWgPort
|
||||
}
|
||||
|
||||
addr.Port = initPort
|
||||
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err == nil {
|
||||
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||
closeConnWithLog(conn)
|
||||
return returnPort, nil
|
||||
return initPort, nil
|
||||
}
|
||||
|
||||
// if the port is already in use, ask the system for a free port
|
||||
|
||||
@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "when port is 0 use random port",
|
||||
name: "not provided, fallback to default",
|
||||
port: 0,
|
||||
want: 0,
|
||||
shouldMatch: false,
|
||||
want: 51820,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "provided and available",
|
||||
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
@@ -39,14 +39,6 @@ func Test_freePort(t *testing.T) {
|
||||
_ = c1.Close()
|
||||
}(c1)
|
||||
|
||||
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
|
||||
tests[1].port++
|
||||
tests[1].want++
|
||||
}
|
||||
|
||||
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -11,10 +11,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityDNSRoute = 100
|
||||
PriorityMatchDomain = 50
|
||||
PriorityDefault = 1
|
||||
)
|
||||
|
||||
type SubdomainMatcher interface {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||
|
||||
// Create test request
|
||||
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "test.example.com.",
|
||||
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
|
||||
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "sub.test.example.com.",
|
||||
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||
|
||||
// Create test request
|
||||
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: map[int]bool{
|
||||
nbdns.PriorityDNSRoute: false,
|
||||
nbdns.PriorityUpstream: true,
|
||||
nbdns.PriorityDNSRoute: false,
|
||||
nbdns.PriorityMatchDomain: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||
{"remove", "example.com.", nbdns.PriorityUpstream},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: map[int]bool{
|
||||
nbdns.PriorityDNSRoute: true,
|
||||
nbdns.PriorityUpstream: false,
|
||||
nbdns.PriorityDNSRoute: true,
|
||||
nbdns.PriorityMatchDomain: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||
{"add", "example.com.", nbdns.PriorityDefault},
|
||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"remove", "example.com.", nbdns.PriorityUpstream},
|
||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: map[int]bool{
|
||||
nbdns.PriorityDNSRoute: false,
|
||||
nbdns.PriorityUpstream: false,
|
||||
nbdns.PriorityDefault: true,
|
||||
nbdns.PriorityDNSRoute: false,
|
||||
nbdns.PriorityMatchDomain: false,
|
||||
nbdns.PriorityDefault: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityUpstream, false, false},
|
||||
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
||||
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||
},
|
||||
query: "example.com.",
|
||||
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
|
||||
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "test.sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "other.example.com.", nbdns.PriorityUpstream, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
|
||||
@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
priority: PriorityLocal,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||
|
||||
for _, domainGroup := range groupedNS {
|
||||
basePriority := PriorityUpstream
|
||||
basePriority := PriorityMatchDomain
|
||||
if domainGroup.domain == nbdns.RootZone {
|
||||
basePriority = PriorityDefault
|
||||
}
|
||||
@@ -588,14 +588,10 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||
priority := basePriority - i
|
||||
|
||||
// Check if we're about to overlap with the next priority tier.
|
||||
// This boundary check ensures that the priority of upstream handlers does not conflict
|
||||
// with the default priority tier. By decrementing the priority for each handler, we avoid
|
||||
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
|
||||
// handlers to maintain the integrity of the priority system.
|
||||
if basePriority == PriorityUpstream && priority <= PriorityDefault {
|
||||
// Check if we're about to overlap with the next priority tier
|
||||
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
||||
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||
domainGroup.domain, PriorityUpstream-PriorityDefault)
|
||||
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
@@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
@@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
@@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityLocal,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
@@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
@@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
|
||||
@@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||
@@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
}
|
||||
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
"upstream-other": {
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
@@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
@@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group3",
|
||||
},
|
||||
priority: PriorityUpstream + 1,
|
||||
priority: PriorityMatchDomain + 1,
|
||||
},
|
||||
// Keep existing groups with their original priorities
|
||||
{
|
||||
@@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
@@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
// Add group3 with lowest priority
|
||||
{
|
||||
@@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group3",
|
||||
},
|
||||
priority: PriorityUpstream - 2,
|
||||
priority: PriorityMatchDomain - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
@@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
@@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityUpstream - 1,
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "new.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-new",
|
||||
},
|
||||
priority: PriorityUpstream,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
@@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
||||
|
||||
// Register domains from different handlers with same domain
|
||||
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
|
||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 2
|
||||
zoneKey := toZone("shared.example.com")
|
||||
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||
|
||||
// Deregister one handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 1
|
||||
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||
@@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
|
||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||
|
||||
@@ -1945,111 +1945,3 @@ func TestDomainCaseHandling(t *testing.T) {
|
||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||
}
|
||||
|
||||
func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
handlerChain: NewHandlerChain(),
|
||||
localResolver: local.NewResolver(),
|
||||
service: &mockService{},
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "local.example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "test.local.example.com",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.100",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"local.example.com"}, // Same domain as local records
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||
assert.NoError(t, err)
|
||||
|
||||
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify that local handler has higher priority than upstream for same domain
|
||||
var localPriority, upstreamPriority int
|
||||
localFound, upstreamFound := false, false
|
||||
|
||||
for _, update := range localMuxUpdates {
|
||||
if update.domain == "local.example.com" {
|
||||
localPriority = update.priority
|
||||
localFound = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, update := range upstreamMuxUpdates {
|
||||
if update.domain == "local.example.com" {
|
||||
upstreamPriority = update.priority
|
||||
upstreamFound = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, localFound, "Local handler should be found")
|
||||
assert.True(t, upstreamFound, "Upstream handler should be found")
|
||||
assert.Greater(t, localPriority, upstreamPriority,
|
||||
"Local handler priority (%d) should be higher than upstream priority (%d)",
|
||||
localPriority, upstreamPriority)
|
||||
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
|
||||
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
|
||||
}
|
||||
|
||||
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
// Test that priority constants are ordered correctly
|
||||
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||
|
||||
// Test that local resolver uses the correct priority
|
||||
server := &DefaultServer{
|
||||
localResolver: local.NewResolver(),
|
||||
}
|
||||
|
||||
config := nbdns.Config{
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "local.example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "test.local.example.com",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.100",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, localMuxUpdates, 1)
|
||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
@@ -104,21 +103,19 @@ func (u *upstreamResolverBase) Stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
requestID := GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
var err error
|
||||
defer func() {
|
||||
u.checkUpstreamFails(err)
|
||||
}()
|
||||
|
||||
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
if r.Extra == nil {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
logger.Tracef("%s has been stopped", u)
|
||||
log.Tracef("%s has been stopped", u)
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -135,35 +132,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||
continue
|
||||
}
|
||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
continue
|
||||
}
|
||||
|
||||
u.successCount.Add(1)
|
||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||
|
||||
if err = w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||
}
|
||||
// count the fails only if they happen sequentially
|
||||
u.failsCount.Store(0)
|
||||
return
|
||||
}
|
||||
u.failsCount.Add(1)
|
||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,13 +385,3 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate request ID: %v", err)
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
@@ -84,10 +84,3 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
return &dns.Client{
|
||||
Timeout: dialTimeout,
|
||||
Net: "udp",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -36,10 +36,3 @@ func newUpstreamResolver(
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||
}
|
||||
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
return &dns.Client{
|
||||
Timeout: dialTimeout,
|
||||
Net: "udp",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -18,20 +18,14 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
type firewaller interface {
|
||||
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
|
||||
}
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
@@ -44,18 +38,16 @@ type DNSForwarder struct {
|
||||
|
||||
mutex sync.RWMutex
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
firewall firewall.Manager
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,17 +57,14 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
// UDP server
|
||||
mux := dns.NewServeMux()
|
||||
f.mux = mux
|
||||
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||
f.dnsServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// TCP server
|
||||
tcpMux := dns.NewServeMux()
|
||||
f.tcpMux = tcpMux
|
||||
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||
f.tcpServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "tcp",
|
||||
@@ -98,13 +87,30 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
// return the first error we get (e.g. bind failure or shutdown)
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if f.mux == nil {
|
||||
log.Debug("DNS mux is nil, skipping domain update")
|
||||
f.fwdEntries = entries
|
||||
return
|
||||
}
|
||||
|
||||
oldDomains := filterDomains(f.fwdEntries)
|
||||
for _, d := range oldDomains {
|
||||
f.mux.HandleRemove(d.PunycodeString())
|
||||
f.tcpMux.HandleRemove(d.PunycodeString())
|
||||
}
|
||||
|
||||
newDomains := filterDomains(entries)
|
||||
for _, d := range newDomains {
|
||||
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
|
||||
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
|
||||
}
|
||||
|
||||
f.fwdEntries = entries
|
||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
@@ -151,31 +157,22 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
return nil
|
||||
}
|
||||
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||
// query doesn't match any configured domain
|
||||
if mostSpecificResId == "" {
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
||||
if err != nil {
|
||||
f.handleDNSError(w, query, resp, domain, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.updateInternalState(domain, ips)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
|
||||
resp := f.handleDNSQuery(w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
@@ -209,8 +206,9 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
|
||||
var prefixes []netip.Prefix
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||
if mostSpecificResId != "" {
|
||||
for _, ip := range ips {
|
||||
var prefix netip.Prefix
|
||||
@@ -341,3 +339,16 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// filterDomains returns a list of normalized domains
|
||||
func filterDomains(entries []*ForwarderEntry) domain.List {
|
||||
newDomains := make(domain.List, 0, len(entries))
|
||||
for _, d := range entries {
|
||||
if d.Domain == "" {
|
||||
log.Warn("empty domain in DNS forwarder")
|
||||
continue
|
||||
}
|
||||
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
|
||||
}
|
||||
return newDomains
|
||||
}
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -23,7 +13,7 @@ import (
|
||||
func Test_getMatchingEntries(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storedMappings map[string]route.ResID
|
||||
storedMappings map[string]route.ResID // key: domain pattern, value: resId
|
||||
queryDomain string
|
||||
expectedResId route.ResID
|
||||
}{
|
||||
@@ -54,7 +44,7 @@ func Test_getMatchingEntries(t *testing.T) {
|
||||
{
|
||||
name: "Wildcard pattern does not match different domain",
|
||||
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
||||
queryDomain: "foo.example.org",
|
||||
queryDomain: "foo.notexample.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
@@ -111,619 +101,3 @@ func Test_getMatchingEntries(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MockFirewall struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
args := m.Called(set, prefixes)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockResolver struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
args := m.Called(ctx, network, host)
|
||||
return args.Get(0).([]netip.Addr), args.Error(1)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configuredDomain string
|
||||
queryDomain string
|
||||
shouldMatch bool
|
||||
expectedResID route.ResID
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "exact domain match should be allowed",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "example.com",
|
||||
shouldMatch: true,
|
||||
expectedResID: "test-res-id",
|
||||
description: "Direct match to configured domain should work",
|
||||
},
|
||||
{
|
||||
name: "subdomain access should be restricted",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "mail.example.com",
|
||||
shouldMatch: false,
|
||||
expectedResID: "",
|
||||
description: "Subdomain should not be accessible unless explicitly configured",
|
||||
},
|
||||
{
|
||||
name: "wildcard should allow subdomains",
|
||||
configuredDomain: "*.example.com",
|
||||
queryDomain: "mail.example.com",
|
||||
shouldMatch: true,
|
||||
expectedResID: "test-res-id",
|
||||
description: "Wildcard domains should allow subdomain access",
|
||||
},
|
||||
{
|
||||
name: "wildcard should allow base domain",
|
||||
configuredDomain: "*.example.com",
|
||||
queryDomain: "example.com",
|
||||
shouldMatch: true,
|
||||
expectedResID: "test-res-id",
|
||||
description: "Wildcard should also match the base domain",
|
||||
},
|
||||
{
|
||||
name: "deep subdomain should be restricted",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "deep.mail.example.com",
|
||||
shouldMatch: false,
|
||||
expectedResID: "",
|
||||
description: "Deep subdomains should not be accessible",
|
||||
},
|
||||
{
|
||||
name: "wildcard allows deep subdomains",
|
||||
configuredDomain: "*.example.com",
|
||||
queryDomain: "deep.mail.example.com",
|
||||
shouldMatch: true,
|
||||
expectedResID: "test-res-id",
|
||||
description: "Wildcard should allow deep subdomains",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder := &DNSForwarder{}
|
||||
|
||||
d, err := domain.FromString(tt.configuredDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := []*ForwarderEntry{
|
||||
{
|
||||
Domain: d,
|
||||
ResID: "test-res-id",
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
|
||||
|
||||
if tt.shouldMatch {
|
||||
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
|
||||
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
|
||||
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
|
||||
assert.Empty(t, matchingEntries, "Expected no matching entries")
|
||||
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configuredDomain string
|
||||
queryDomain string
|
||||
shouldResolve bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "configured exact domain resolves",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "example.com",
|
||||
shouldResolve: true,
|
||||
description: "Exact match should resolve",
|
||||
},
|
||||
{
|
||||
name: "unauthorized subdomain blocked",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "mail.example.com",
|
||||
shouldResolve: false,
|
||||
description: "Subdomain should be blocked without wildcard",
|
||||
},
|
||||
{
|
||||
name: "wildcard allows subdomain",
|
||||
configuredDomain: "*.example.com",
|
||||
queryDomain: "mail.example.com",
|
||||
shouldResolve: true,
|
||||
description: "Wildcard should allow subdomain",
|
||||
},
|
||||
{
|
||||
name: "wildcard allows base domain",
|
||||
configuredDomain: "*.example.com",
|
||||
queryDomain: "example.com",
|
||||
shouldResolve: true,
|
||||
description: "Wildcard should allow base domain",
|
||||
},
|
||||
{
|
||||
name: "unrelated domain blocked",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "example.org",
|
||||
shouldResolve: false,
|
||||
description: "Unrelated domain should be blocked",
|
||||
},
|
||||
{
|
||||
name: "deep subdomain blocked",
|
||||
configuredDomain: "example.com",
|
||||
queryDomain: "deep.mail.example.com",
|
||||
shouldResolve: false,
|
||||
description: "Deep subdomain should be blocked",
|
||||
},
|
||||
{
|
||||
name: "wildcard allows deep subdomain",
|
||||
configuredDomain: "*.example.com",
|
||||
queryDomain: "deep.mail.example.com",
|
||||
shouldResolve: true,
|
||||
description: "Wildcard should allow deep subdomain",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
if tt.shouldResolve {
|
||||
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
|
||||
|
||||
// Mock successful DNS resolution
|
||||
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||
}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString(tt.configuredDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := []*ForwarderEntry{
|
||||
{
|
||||
Domain: d,
|
||||
ResID: "test-res-id",
|
||||
Set: firewall.NewDomainSet([]domain.Domain{d}),
|
||||
},
|
||||
}
|
||||
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
mockFirewall.AssertExpectations(t)
|
||||
mockResolver.AssertExpectations(t)
|
||||
} else {
|
||||
if resp != nil {
|
||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||
"Unauthorized domain should not return successful answers")
|
||||
}
|
||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configuredDomains []string
|
||||
query string
|
||||
mockIP string
|
||||
shouldResolve bool
|
||||
expectedSetCount int // How many sets should be updated
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "exact domain gets firewall update",
|
||||
configuredDomains: []string{"example.com"},
|
||||
query: "example.com",
|
||||
mockIP: "1.1.1.1",
|
||||
shouldResolve: true,
|
||||
expectedSetCount: 1,
|
||||
description: "Single exact match updates one set",
|
||||
},
|
||||
{
|
||||
name: "wildcard domain gets firewall update",
|
||||
configuredDomains: []string{"*.example.com"},
|
||||
query: "mail.example.com",
|
||||
mockIP: "1.1.1.2",
|
||||
shouldResolve: true,
|
||||
expectedSetCount: 1,
|
||||
description: "Wildcard match updates one set",
|
||||
},
|
||||
{
|
||||
name: "overlapping exact and wildcard both get updates",
|
||||
configuredDomains: []string{"*.example.com", "mail.example.com"},
|
||||
query: "mail.example.com",
|
||||
mockIP: "1.1.1.3",
|
||||
shouldResolve: true,
|
||||
expectedSetCount: 2,
|
||||
description: "Both exact and wildcard sets should be updated",
|
||||
},
|
||||
{
|
||||
name: "unauthorized domain gets no firewall update",
|
||||
configuredDomains: []string{"example.com"},
|
||||
query: "mail.example.com",
|
||||
mockIP: "1.1.1.4",
|
||||
shouldResolve: false,
|
||||
expectedSetCount: 0,
|
||||
description: "No firewall update for unauthorized domains",
|
||||
},
|
||||
{
|
||||
name: "multiple wildcards matching get all updated",
|
||||
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
|
||||
query: "test.sub.example.com",
|
||||
mockIP: "1.1.1.5",
|
||||
shouldResolve: true,
|
||||
expectedSetCount: 2,
|
||||
description: "All matching wildcard sets should be updated",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
// Set up forwarder
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Create entries and track sets
|
||||
var entries []*ForwarderEntry
|
||||
sets := make([]firewall.Set, 0)
|
||||
|
||||
for i, configDomain := range tt.configuredDomains {
|
||||
d, err := domain.FromString(configDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||
sets = append(sets, set)
|
||||
|
||||
entries = append(entries, &ForwarderEntry{
|
||||
Domain: d,
|
||||
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
|
||||
Set: set,
|
||||
})
|
||||
}
|
||||
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
// Set up mocks
|
||||
if tt.shouldResolve {
|
||||
fakeIP := netip.MustParseAddr(tt.mockIP)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
|
||||
Return([]netip.Addr{fakeIP}, nil).Once()
|
||||
|
||||
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
|
||||
|
||||
// Count how many sets should actually match
|
||||
updateCount := 0
|
||||
for i, entry := range entries {
|
||||
domain := strings.ToLower(tt.query)
|
||||
pattern := entry.Domain.PunycodeString()
|
||||
|
||||
matches := false
|
||||
if strings.HasPrefix(pattern, "*.") {
|
||||
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||
matches = true
|
||||
}
|
||||
} else if domain == pattern {
|
||||
matches = true
|
||||
}
|
||||
|
||||
if matches {
|
||||
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
|
||||
updateCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedSetCount, updateCount,
|
||||
"Expected %d sets to be updated, but mock expects %d",
|
||||
tt.expectedSetCount, updateCount)
|
||||
}
|
||||
|
||||
// Execute query
|
||||
dnsQuery := &dns.Msg{}
|
||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
||||
|
||||
// Verify response
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.NotEmpty(t, resp.Answer)
|
||||
} else if resp != nil {
|
||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||
"Unauthorized domain should be refused or have no answers")
|
||||
}
|
||||
|
||||
// Verify all mock expectations were met
|
||||
mockFirewall.AssertExpectations(t)
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
|
||||
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Configure a single domain
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||
entries := []*ForwarderEntry{{
|
||||
Domain: d,
|
||||
ResID: "test-res",
|
||||
Set: set,
|
||||
}}
|
||||
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
// Mock resolver returns multiple IPs
|
||||
ips := []netip.Addr{
|
||||
netip.MustParseAddr("1.1.1.1"),
|
||||
netip.MustParseAddr("1.1.1.2"),
|
||||
netip.MustParseAddr("1.1.1.3"),
|
||||
}
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return(ips, nil).Once()
|
||||
|
||||
// Expect ONE UpdateSet call with ALL prefixes
|
||||
expectedPrefixes := []netip.Prefix{
|
||||
netip.PrefixFrom(ips[0], 32),
|
||||
netip.PrefixFrom(ips[1], 32),
|
||||
netip.PrefixFrom(ips[2], 32),
|
||||
}
|
||||
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
|
||||
|
||||
// Execute query
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
// Verify response contains all IPs
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||
|
||||
// Verify mocks
|
||||
mockFirewall.AssertExpectations(t)
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryType uint16
|
||||
queryDomain string
|
||||
configured string
|
||||
expectedCode int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "unauthorized domain returns REFUSED",
|
||||
queryType: dns.TypeA,
|
||||
queryDomain: "evil.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeRefused,
|
||||
description: "RFC compliant REFUSED for unauthorized queries",
|
||||
},
|
||||
{
|
||||
name: "unsupported query type returns NOTIMP",
|
||||
queryType: dns.TypeMX,
|
||||
queryDomain: "example.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeNotImplemented,
|
||||
description: "RFC compliant NOTIMP for unsupported types",
|
||||
},
|
||||
{
|
||||
name: "CNAME query returns NOTIMP",
|
||||
queryType: dns.TypeCNAME,
|
||||
queryDomain: "example.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeNotImplemented,
|
||||
description: "CNAME queries not supported",
|
||||
},
|
||||
{
|
||||
name: "TXT query returns NOTIMP",
|
||||
queryType: dns.TypeTXT,
|
||||
queryDomain: "example.com",
|
||||
configured: "example.com",
|
||||
expectedCode: dns.RcodeNotImplemented,
|
||||
description: "TXT queries not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
|
||||
d, err := domain.FromString(tt.configured)
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
||||
|
||||
// Capture the written response
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
_ = forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, _ := domain.FromString("example.com")
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
// Mock many IPs to create a large response
|
||||
var manyIPs []netip.Addr
|
||||
for i := 0; i < 100; i++ {
|
||||
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
|
||||
}
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
|
||||
|
||||
// Query without EDNS0
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
forwarder.handleDNSQueryUDP(mockWriter, query)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
|
||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||
}
|
||||
|
||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
// Test complex overlapping pattern scenarios
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Set up complex overlapping patterns
|
||||
patterns := []string{
|
||||
"*.example.com", // Matches all subdomains
|
||||
"*.mail.example.com", // More specific wildcard
|
||||
"smtp.mail.example.com", // Exact match
|
||||
"example.com", // Base domain
|
||||
}
|
||||
|
||||
var entries []*ForwarderEntry
|
||||
sets := make(map[string]firewall.Set)
|
||||
|
||||
for _, pattern := range patterns {
|
||||
d, _ := domain.FromString(pattern)
|
||||
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||
sets[pattern] = set
|
||||
entries = append(entries, &ForwarderEntry{
|
||||
Domain: d,
|
||||
ResID: route.ResID("res-" + pattern),
|
||||
Set: set,
|
||||
})
|
||||
}
|
||||
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
// Test smtp.mail.example.com - should match 3 patterns
|
||||
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
|
||||
|
||||
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
|
||||
// All three matching patterns should get firewall updates
|
||||
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
|
||||
// Verify all three sets were updated
|
||||
mockFirewall.AssertExpectations(t)
|
||||
|
||||
// Verify the most specific ResID was selected
|
||||
// (exact match should win over wildcards)
|
||||
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
|
||||
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
|
||||
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||
}
|
||||
|
||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
// Test handling of malformed query with no questions
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
writeCalled := false
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writeCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
assert.Nil(t, resp, "Should return nil for empty query")
|
||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
@@ -174,7 +175,8 @@ type Engine struct {
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
|
||||
statusRecorder *peer.Status
|
||||
statusRecorder *peer.Status
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
|
||||
firewall firewallManager.Manager
|
||||
routeManager routemanager.Manager
|
||||
@@ -381,13 +383,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.stateManager.Start()
|
||||
|
||||
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("read initial settings: %w", err)
|
||||
}
|
||||
|
||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
@@ -404,7 +400,6 @@ func (e *Engine) Start() error {
|
||||
InitialRoutes: initialRoutes,
|
||||
StateManager: e.stateManager,
|
||||
DNSServer: dnsServer,
|
||||
DNSFeatureFlag: dnsFeatureFlag,
|
||||
PeerStore: e.peerStore,
|
||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||
@@ -456,7 +451,9 @@ func (e *Engine) Start() error {
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
@@ -491,9 +488,9 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
|
||||
if e.config.BlockLANAccess {
|
||||
@@ -1012,6 +1009,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
@@ -1022,7 +1021,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||
}
|
||||
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
@@ -1257,7 +1255,7 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
}
|
||||
|
||||
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||
conn.Close(false)
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
|
||||
@@ -1304,12 +1302,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
PeerConnDispatcher: e.peerConnDispatcher,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1332,16 +1331,11 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
}
|
||||
|
||||
msgType := msg.GetBody().GetType()
|
||||
if msgType != sProto.Body_GO_IDLE {
|
||||
e.connMgr.ActivatePeer(e.ctx, conn)
|
||||
}
|
||||
|
||||
switch msg.GetBody().Type {
|
||||
case sProto.Body_OFFER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
@@ -1398,8 +1392,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
|
||||
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||
case sProto.Body_MODE:
|
||||
case sProto.Body_GO_IDLE:
|
||||
e.connMgr.DeactivatePeer(conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1497,12 +1489,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
if runtime.GOOS != "android" {
|
||||
// nolint:nilnil
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
info := system.GetInfo(e.ctx)
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
@@ -1519,12 +1506,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||
dnsFeatureFlag := toDNSFeatureFlag(netMap)
|
||||
return routes, &dnsCfg, dnsFeatureFlag, nil
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
|
||||
func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
@@ -1541,7 +1527,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
@@ -1572,14 +1557,18 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
// due to tests where we are using a mocked version of the DNS server
|
||||
if e.dnsServer != nil {
|
||||
return e.dnsServer, nil
|
||||
return nil, e.dnsServer, nil
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
routes, dnsConfig, err := e.readInitialSettings()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
||||
e.ctx,
|
||||
e.wgInterface,
|
||||
@@ -1590,19 +1579,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
e.config.DisableDNS,
|
||||
)
|
||||
go e.mobileDep.DnsReadyListener.OnReady()
|
||||
return dnsServer, nil
|
||||
return routes, dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
return nil, dnsServer, nil
|
||||
|
||||
default:
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return dnsServer, nil
|
||||
return nil, dnsServer, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -96,7 +97,6 @@ type MockWGIface struct {
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
LastActivitiesFunc func() map[string]time.Time
|
||||
}
|
||||
|
||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||
@@ -187,13 +187,6 @@ func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) LastActivities() map[string]time.Time {
|
||||
if m.LastActivitiesFunc != nil {
|
||||
return m.LastActivitiesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
@@ -411,7 +404,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
type testCase struct {
|
||||
@@ -800,7 +793,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
|
||||
engine.routeManager = mockRouteManager
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
@@ -998,7 +991,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
engine.dnsServer = mockDNSServer
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||
engine.connMgr.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
@@ -1483,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -38,5 +38,4 @@ type wgIfaceBase interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
GetNet() *netstack.Net
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]time.Time
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||
type Listener struct {
|
||||
wgIface WgInterface
|
||||
wgIface lazyconn.WGIface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
conn *net.UDPConn
|
||||
endpoint *net.UDPAddr
|
||||
@@ -22,7 +22,7 @@ type Listener struct {
|
||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||
}
|
||||
|
||||
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
d := &Listener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
|
||||
@@ -1,27 +1,18 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type WgInterface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
OnActivityChan chan peerid.ConnID
|
||||
|
||||
wgIface WgInterface
|
||||
wgIface lazyconn.WGIface
|
||||
|
||||
peers map[peerid.ConnID]*Listener
|
||||
done chan struct{}
|
||||
@@ -29,7 +20,7 @@ type Manager struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(wgIface WgInterface) *Manager {
|
||||
func NewManager(wgIface lazyconn.WGIface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
|
||||
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
peer "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
|
||||
MinimumInactivityThreshold = 3 * time.Minute
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
id peer.ConnID
|
||||
timer *time.Timer
|
||||
cancel context.CancelFunc
|
||||
inactivityThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
|
||||
i := &Monitor{
|
||||
id: peerID,
|
||||
timer: time.NewTimer(0),
|
||||
inactivityThreshold: threshold,
|
||||
}
|
||||
i.timer.Stop()
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
defer i.timer.Stop()
|
||||
|
||||
ctx, i.cancel = context.WithCancel(ctx)
|
||||
defer func() {
|
||||
defer i.cancel()
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-i.timer.C:
|
||||
select {
|
||||
case timeoutChan <- i.id:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Monitor) Stop() {
|
||||
if i.cancel == nil {
|
||||
return
|
||||
}
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
func (i *Monitor) PauseTimer() {
|
||||
i.timer.Stop()
|
||||
}
|
||||
|
||||
func (i *Monitor) ResetTimer() {
|
||||
i.timer.Reset(i.inactivityThreshold)
|
||||
}
|
||||
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
}
|
||||
|
||||
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
return peerid.ConnID(m)
|
||||
}
|
||||
|
||||
func TestInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseInactivityMonitor(t *testing.T) {
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
for i := 2; i > 0; i-- {
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(testTimeoutCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-exitChan:
|
||||
case <-testTimeoutCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
testTimeoutCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(tCtx, timeoutChan)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(3 * time.Second)
|
||||
im.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-exitChan:
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPauseInactivityMonitor(t *testing.T) {
|
||||
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer testTimeoutCancel()
|
||||
|
||||
p := &MocPeer{}
|
||||
trashHold := time.Second * 3
|
||||
im := NewInactivityMonitor(p.ConnID(), trashHold)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
timeoutChan := make(chan peerid.ConnID)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(exitChan)
|
||||
im.Start(ctx, timeoutChan)
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second) // grant time to start the monitor
|
||||
im.PauseTimer()
|
||||
|
||||
// check to do not receive timeout
|
||||
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
|
||||
defer thresholdCancel()
|
||||
select {
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
t.Fatal("unexpected timeout")
|
||||
case <-thresholdCtx.Done():
|
||||
// test ok
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
}
|
||||
|
||||
// test reset timer
|
||||
im.ResetTimer()
|
||||
|
||||
select {
|
||||
case <-tCtx.Done():
|
||||
t.Fatal("test timed out")
|
||||
case <-exitChan:
|
||||
t.Fatal("unexpected exit")
|
||||
case <-timeoutChan:
|
||||
// expected timeout
|
||||
}
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
const (
|
||||
checkInterval = 1 * time.Minute
|
||||
|
||||
DefaultInactivityThreshold = 15 * time.Minute
|
||||
MinimumInactivityThreshold = 1 * time.Minute
|
||||
)
|
||||
|
||||
type WgInterface interface {
|
||||
LastActivities() map[string]time.Time
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
inactivePeersChan chan map[string]struct{}
|
||||
|
||||
iface WgInterface
|
||||
interestedPeers map[string]*lazyconn.PeerConfig
|
||||
inactivityThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
|
||||
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
|
||||
if err != nil {
|
||||
inactivityThreshold = DefaultInactivityThreshold
|
||||
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
|
||||
}
|
||||
|
||||
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
|
||||
return &Manager{
|
||||
inactivePeersChan: make(chan map[string]struct{}, 1),
|
||||
iface: iface,
|
||||
interestedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||
inactivityThreshold: inactivityThreshold,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
|
||||
if m == nil {
|
||||
// return a nil channel that blocks forever
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.inactivePeersChan
|
||||
}
|
||||
|
||||
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
|
||||
return
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("adding peer to inactivity manager")
|
||||
m.interestedPeers[peerCfg.PublicKey] = peerCfg
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(peer string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pi, ok := m.interestedPeers[peer]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pi.Log.Debugf("remove peer from inactivity manager")
|
||||
delete(m.interestedPeers, peer)
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := newTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C():
|
||||
idlePeers, err := m.checkStats()
|
||||
if err != nil {
|
||||
log.Errorf("error checking stats: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(idlePeers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m.notifyInactivePeers(ctx, idlePeers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
|
||||
select {
|
||||
case m.inactivePeersChan <- inactivePeers:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkStats() (map[string]struct{}, error) {
|
||||
lastActivities := m.iface.LastActivities()
|
||||
|
||||
idlePeers := make(map[string]struct{})
|
||||
|
||||
for peerID, peerCfg := range m.interestedPeers {
|
||||
lastActive, ok := lastActivities[peerID]
|
||||
if !ok {
|
||||
// when peer is in connecting state
|
||||
peerCfg.Log.Warnf("peer not found in wg stats")
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(lastActive) > m.inactivityThreshold {
|
||||
peerCfg.Log.Infof("peer is inactive since: %v", lastActive)
|
||||
idlePeers[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return idlePeers, nil
|
||||
}
|
||||
|
||||
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
|
||||
if configuredThreshold == nil {
|
||||
return DefaultInactivityThreshold, nil
|
||||
}
|
||||
if *configuredThreshold < MinimumInactivityThreshold {
|
||||
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
|
||||
}
|
||||
return *configuredThreshold, nil
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package inactivity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
type mockWgInterface struct {
|
||||
lastActivities map[string]time.Time
|
||||
}
|
||||
|
||||
func (m *mockWgInterface) LastActivities() map[string]time.Time {
|
||||
return m.lastActivities
|
||||
}
|
||||
|
||||
func TestPeerTriggersInactivity(t *testing.T) {
|
||||
peerID := "peer1"
|
||||
|
||||
wgMock := &mockWgInterface{
|
||||
lastActivities: map[string]time.Time{
|
||||
peerID: time.Now().Add(-20 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
fakeTick := make(chan time.Time, 1)
|
||||
newTicker = func(d time.Duration) Ticker {
|
||||
return &fakeTickerMock{CChan: fakeTick}
|
||||
}
|
||||
|
||||
peerLog := log.WithField("peer", peerID)
|
||||
peerCfg := &lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
Log: peerLog,
|
||||
}
|
||||
|
||||
manager := NewManager(wgMock, nil)
|
||||
manager.AddPeer(peerCfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start the manager in a goroutine
|
||||
go manager.Start(ctx)
|
||||
|
||||
// Send a tick to simulate time passage
|
||||
fakeTick <- time.Now()
|
||||
|
||||
// Check if peer appears on inactivePeersChan
|
||||
select {
|
||||
case inactivePeers := <-manager.inactivePeersChan:
|
||||
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected inactivity event, but none received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerTriggersActivity(t *testing.T) {
|
||||
peerID := "peer1"
|
||||
|
||||
wgMock := &mockWgInterface{
|
||||
lastActivities: map[string]time.Time{
|
||||
peerID: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
fakeTick := make(chan time.Time, 1)
|
||||
newTicker = func(d time.Duration) Ticker {
|
||||
return &fakeTickerMock{CChan: fakeTick}
|
||||
}
|
||||
|
||||
peerLog := log.WithField("peer", peerID)
|
||||
peerCfg := &lazyconn.PeerConfig{
|
||||
PublicKey: peerID,
|
||||
Log: peerLog,
|
||||
}
|
||||
|
||||
manager := NewManager(wgMock, nil)
|
||||
manager.AddPeer(peerCfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start the manager in a goroutine
|
||||
go manager.Start(ctx)
|
||||
|
||||
// Send a tick to simulate time passage
|
||||
fakeTick <- time.Now()
|
||||
|
||||
// Check if peer appears on inactivePeersChan
|
||||
select {
|
||||
case <-manager.inactivePeersChan:
|
||||
t.Fatal("expected inactive peer to be marked inactive")
|
||||
case <-time.After(1 * time.Second):
|
||||
// No inactivity event should be received
|
||||
}
|
||||
}
|
||||
|
||||
// fakeTickerMock implements Ticker interface for testing
|
||||
type fakeTickerMock struct {
|
||||
CChan chan time.Time
|
||||
}
|
||||
|
||||
func (f *fakeTickerMock) C() <-chan time.Time {
|
||||
return f.CChan
|
||||
}
|
||||
|
||||
func (f *fakeTickerMock) Stop() {}
|
||||
@@ -1,24 +0,0 @@
|
||||
package inactivity
|
||||
|
||||
import "time"
|
||||
|
||||
var newTicker = func(d time.Duration) Ticker {
|
||||
return &realTicker{t: time.NewTicker(d)}
|
||||
}
|
||||
|
||||
type Ticker interface {
|
||||
C() <-chan time.Time
|
||||
Stop()
|
||||
}
|
||||
|
||||
type realTicker struct {
|
||||
t *time.Ticker
|
||||
}
|
||||
|
||||
func (r *realTicker) C() <-chan time.Time {
|
||||
return r.t.C
|
||||
}
|
||||
|
||||
func (r *realTicker) Stop() {
|
||||
r.t.Stop()
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -42,46 +43,60 @@ type Config struct {
|
||||
type Manager struct {
|
||||
engineCtx context.Context
|
||||
peerStore *peerstore.Store
|
||||
connStateDispatcher *dispatcher.ConnectionDispatcher
|
||||
inactivityThreshold time.Duration
|
||||
|
||||
connStateListener *dispatcher.ConnectionListener
|
||||
managedPeers map[string]*lazyconn.PeerConfig
|
||||
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
||||
excludes map[string]lazyconn.PeerConfig
|
||||
managedPeersMu sync.Mutex
|
||||
|
||||
activityManager *activity.Manager
|
||||
inactivityManager *inactivity.Manager
|
||||
activityManager *activity.Manager
|
||||
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
|
||||
|
||||
// Route HA group management
|
||||
// If any peer in the same HA group is active, all peers in that group should prevent going idle
|
||||
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
|
||||
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
|
||||
routesMu sync.RWMutex
|
||||
routesMu sync.RWMutex // protects route mappings
|
||||
|
||||
onInactive chan peerid.ConnID
|
||||
}
|
||||
|
||||
// NewManager creates a new lazy connection manager
|
||||
// engineCtx is the context for creating peer Connection
|
||||
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
|
||||
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
|
||||
log.Infof("setup lazy connection service")
|
||||
|
||||
m := &Manager{
|
||||
engineCtx: engineCtx,
|
||||
peerStore: peerStore,
|
||||
connStateDispatcher: connStateDispatcher,
|
||||
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
||||
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
||||
excludes: make(map[string]lazyconn.PeerConfig),
|
||||
activityManager: activity.NewManager(wgIface),
|
||||
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
|
||||
peerToHAGroups: make(map[string][]route.HAUniqueID),
|
||||
haGroupToPeers: make(map[route.HAUniqueID][]string),
|
||||
onInactive: make(chan peerid.ConnID),
|
||||
}
|
||||
|
||||
if wgIface.IsUserspaceBind() {
|
||||
m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
|
||||
} else {
|
||||
log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
|
||||
if config.InactivityThreshold != nil {
|
||||
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
|
||||
m.inactivityThreshold = *config.InactivityThreshold
|
||||
} else {
|
||||
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
m.connStateListener = &dispatcher.ConnectionListener{
|
||||
OnConnected: m.onPeerConnected,
|
||||
OnDisconnected: m.onPeerDisconnected,
|
||||
}
|
||||
|
||||
connStateDispatcher.AddListener(m.connStateListener)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -116,28 +131,24 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
|
||||
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
|
||||
len(m.haGroupToPeers), len(m.peerToHAGroups))
|
||||
}
|
||||
|
||||
// Start starts the manager and listens for peer activity and inactivity events
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
defer m.close()
|
||||
|
||||
if m.inactivityManager != nil {
|
||||
go m.inactivityManager.Start(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(peerConnID)
|
||||
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
||||
m.onPeerInactivityTimedOut(peerIDs)
|
||||
m.onPeerActivity(ctx, peerConnID)
|
||||
case peerConnID := <-m.onInactive:
|
||||
m.onPeerInactivityTimedOut(peerConnID)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ExcludePeer marks peers for a permanent connection
|
||||
@@ -145,7 +156,7 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
||||
// this case, we suppose that the connection status is connected or connecting.
|
||||
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
||||
func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
|
||||
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -176,7 +187,7 @@ func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
|
||||
|
||||
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
||||
|
||||
if err := m.addActivePeer(&peerCfg); err != nil {
|
||||
if err := m.addActivePeer(ctx, peerCfg); err != nil {
|
||||
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
||||
continue
|
||||
}
|
||||
@@ -206,24 +217,20 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherActivity,
|
||||
}
|
||||
|
||||
// Check if this peer should be activated because its HA group peers are active
|
||||
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
|
||||
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
|
||||
m.activateNewPeerInActiveGroup(peerCfg)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// AddActivePeers adds a list of peers to the lazy connection manager
|
||||
// suppose these peers was in connected or in connecting states
|
||||
func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
|
||||
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -233,7 +240,7 @@ func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.addActivePeer(&cfg); err != nil {
|
||||
if err := m.addActivePeer(ctx, cfg); err != nil {
|
||||
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -250,7 +257,7 @@ func (m *Manager) RemovePeer(peerID string) {
|
||||
|
||||
// ActivatePeer activates a peer connection when a signal message is received
|
||||
// Also activates all peers in the same HA groups as this peer
|
||||
func (m *Manager) ActivatePeer(peerID string) (found bool) {
|
||||
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
cfg, mp := m.getPeerForActivation(peerID)
|
||||
@@ -258,42 +265,15 @@ func (m *Manager) ActivatePeer(peerID string) (found bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.activateSinglePeer(cfg, mp) {
|
||||
if !m.activateSinglePeer(ctx, cfg, mp) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.activateHAGroupPeers(cfg)
|
||||
m.activateHAGroupPeers(ctx, peerID)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
|
||||
// Returns nil values if the peer should be skipped
|
||||
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
|
||||
@@ -315,120 +295,82 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
|
||||
return cfg, mp
|
||||
}
|
||||
|
||||
// activateSinglePeer activates a single peer
|
||||
// return true if the peer was activated, false if it was already active
|
||||
func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
|
||||
if mp.expectedWatcher == watcherInactivity {
|
||||
// activateSinglePeer activates a single peer (internal method)
|
||||
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
|
||||
im, ok := m.inactivityMonitors[cfg.PeerConnID]
|
||||
if !ok {
|
||||
cfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return false
|
||||
}
|
||||
|
||||
mp.expectedWatcher = watcherInactivity
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
m.inactivityManager.AddPeer(cfg)
|
||||
cfg.Log.Infof("starting inactivity monitor")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
|
||||
func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
|
||||
var peersToActivate []string
|
||||
|
||||
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
|
||||
m.routesMu.RLock()
|
||||
haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
|
||||
haGroups := m.peerToHAGroups[triggerPeerID]
|
||||
m.routesMu.RUnlock()
|
||||
|
||||
if len(haGroups) == 0 {
|
||||
m.routesMu.RUnlock()
|
||||
triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
|
||||
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
|
||||
return
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
peers := m.haGroupToPeers[haGroup]
|
||||
for _, peerID := range peers {
|
||||
if peerID != triggeredPeerCfg.PublicKey {
|
||||
peersToActivate = append(peersToActivate, peerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.routesMu.RUnlock()
|
||||
|
||||
activatedCount := 0
|
||||
for _, peerID := range peersToActivate {
|
||||
cfg, mp := m.getPeerForActivation(peerID)
|
||||
if cfg == nil {
|
||||
continue
|
||||
}
|
||||
for _, haGroup := range haGroups {
|
||||
m.routesMu.RLock()
|
||||
peers := m.haGroupToPeers[haGroup]
|
||||
m.routesMu.RUnlock()
|
||||
|
||||
if m.activateSinglePeer(cfg, mp) {
|
||||
activatedCount++
|
||||
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
||||
for _, peerID := range peers {
|
||||
if peerID == triggerPeerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, mp := m.getPeerForActivation(peerID)
|
||||
if cfg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.activateSinglePeer(ctx, cfg, mp) {
|
||||
activatedCount++
|
||||
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID)
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if activatedCount > 0 {
|
||||
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
|
||||
activatedCount, triggeredPeerCfg.PublicKey, haGroups)
|
||||
activatedCount, triggerPeerID, haGroups)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldActivateNewPeer checks if a newly added peer should be activated
|
||||
// because other peers in its HA groups are already active
|
||||
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
|
||||
m.routesMu.RLock()
|
||||
defer m.routesMu.RUnlock()
|
||||
|
||||
haGroups := m.peerToHAGroups[peerID]
|
||||
if len(haGroups) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
peers := m.haGroupToPeers[haGroup]
|
||||
for _, groupPeerID := range peers {
|
||||
if groupPeerID == peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, ok := m.managedPeers[groupPeerID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
|
||||
return haGroup, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
|
||||
func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
|
||||
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if !m.activateSinglePeer(&peerCfg, mp) {
|
||||
return
|
||||
}
|
||||
|
||||
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
|
||||
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
||||
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||
peerCfg.Log.Warnf("peer already managed")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = peerCfg
|
||||
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||
|
||||
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||
peerCfg: peerCfg,
|
||||
peerCfg: &peerCfg,
|
||||
expectedWatcher: watcherInactivity,
|
||||
}
|
||||
|
||||
m.inactivityManager.AddPeer(peerCfg)
|
||||
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
|
||||
go im.Start(ctx, m.onInactive)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -440,7 +382,12 @@ func (m *Manager) removePeer(peerID string) {
|
||||
|
||||
cfg.Log.Infof("removing lazy peer")
|
||||
|
||||
m.inactivityManager.RemovePeer(cfg.PublicKey)
|
||||
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
|
||||
im.Stop()
|
||||
delete(m.inactivityMonitors, cfg.PeerConnID)
|
||||
cfg.Log.Debugf("inactivity monitor stopped")
|
||||
}
|
||||
|
||||
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||
delete(m.managedPeers, peerID)
|
||||
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
||||
@@ -450,8 +397,12 @@ func (m *Manager) close() {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
m.connStateDispatcher.RemoveListener(m.connStateListener)
|
||||
m.activityManager.Close()
|
||||
|
||||
for _, iw := range m.inactivityMonitors {
|
||||
iw.Stop()
|
||||
}
|
||||
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
|
||||
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
||||
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
||||
|
||||
@@ -464,56 +415,7 @@ func (m *Manager) close() {
|
||||
log.Infof("lazy connection manager closed")
|
||||
}
|
||||
|
||||
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
|
||||
func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
|
||||
m.routesMu.RLock()
|
||||
defer m.routesMu.RUnlock()
|
||||
|
||||
haGroups := m.peerToHAGroups[peerID]
|
||||
if len(haGroups) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, haGroup := range haGroups {
|
||||
if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
|
||||
groupPeers := m.haGroupToPeers[haGroup]
|
||||
for _, groupPeerID := range groupPeers {
|
||||
|
||||
if groupPeerID == peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
cfg, ok := m.managedPeers[groupPeerID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if groupMp.expectedWatcher != watcherInactivity {
|
||||
continue
|
||||
}
|
||||
|
||||
// If any peer in the group is active, do defer idle
|
||||
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
@@ -530,56 +432,89 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
|
||||
mp.peerCfg.Log.Infof("detected peer activity")
|
||||
|
||||
if !m.activateSinglePeer(mp.peerCfg, mp) {
|
||||
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
|
||||
return
|
||||
}
|
||||
|
||||
m.activateHAGroupPeers(mp.peerCfg)
|
||||
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
|
||||
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
for peerID := range peerIDs {
|
||||
peerCfg, ok := m.managedPeers[peerID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by peerId: %v", peerID)
|
||||
continue
|
||||
}
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by id: %v", peerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
|
||||
continue
|
||||
}
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||
continue
|
||||
}
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
|
||||
if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
|
||||
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
|
||||
continue
|
||||
}
|
||||
// this is blocking operation, potentially can be optimized
|
||||
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||
|
||||
mp.peerCfg.Log.Infof("connection timed out")
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
|
||||
// this is blocking operation, potentially can be optimized
|
||||
m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
mp.peerCfg.Log.Infof("start activity monitor")
|
||||
// just in case free up
|
||||
m.inactivityMonitors[peerConnID].PauseTimer()
|
||||
|
||||
mp.expectedWatcher = watcherActivity
|
||||
|
||||
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
|
||||
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
|
||||
iw.PauseTimer()
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mp.expectedWatcher != watcherInactivity {
|
||||
return
|
||||
}
|
||||
|
||||
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
|
||||
iw.ResetTimer()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,4 @@ import (
|
||||
type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
LastActivities() map[string]time.Time
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
eventStr = "Ended"
|
||||
}
|
||||
|
||||
log.Tracef("%s %s %s connection: %s:%d → %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
c.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
|
||||
@@ -117,9 +117,10 @@ type Conn struct {
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
wg sync.WaitGroup
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -135,17 +136,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
peerConnDispatcher: services.PeerConnDispatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -224,7 +226,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
}
|
||||
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
func (conn *Conn) Close(signalToRemote bool) {
|
||||
func (conn *Conn) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
@@ -234,12 +236,6 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if signalToRemote {
|
||||
if err := conn.signaler.SignalIdle(conn.config.Key); err != nil {
|
||||
conn.Log.Errorf("failed to signal idle state to peer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn.Log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
@@ -321,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
|
||||
return conn.config.WgConfig
|
||||
}
|
||||
|
||||
// IsConnected returns true if the peer is connected
|
||||
// IsConnected unit tests only
|
||||
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
|
||||
func (conn *Conn) IsConnected() bool {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
return conn.evalStatus() == StatusConnected
|
||||
return conn.currentConnPriority != conntype.None
|
||||
}
|
||||
|
||||
func (conn *Conn) GetKey() string {
|
||||
@@ -408,10 +404,15 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
oldState := conn.currentConnPriority
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
|
||||
if oldState == conntype.None {
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) onICEStateDisconnected() {
|
||||
@@ -449,6 +450,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
@@ -528,6 +530,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.Log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||
}
|
||||
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
@@ -542,7 +545,11 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
|
||||
@@ -68,13 +68,3 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Signaler) SignalIdle(remoteKey string) error {
|
||||
return s.signal.Send(&sProto.Message{
|
||||
Key: s.wgPrivateKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey,
|
||||
Body: &sProto.Body{
|
||||
Type: sProto.Body_GO_IDLE,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -575,12 +575,13 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
if !d.peerListChangedForNotification {
|
||||
d.mux.Unlock()
|
||||
return
|
||||
}
|
||||
d.peerListChangedForNotification = false
|
||||
d.mux.Unlock()
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
|
||||
|
||||
@@ -146,8 +146,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: formatEndpoint(pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: formatEndpoint(pair.Remote.Address(), pair.Remote.Port()),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
Relayed: isRelayed(pair),
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
}
|
||||
@@ -405,12 +405,3 @@ func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
|
||||
return conntype.ICEP2P
|
||||
}
|
||||
}
|
||||
|
||||
// formatEndpoint formats an IP address and port for display, adding brackets around IPv6 addresses
|
||||
func formatEndpoint(addr string, port int) string {
|
||||
parsed, err := netip.ParseAddr(addr)
|
||||
if err == nil && parsed.Is6() {
|
||||
return fmt.Sprintf("[%s]:%d", addr, port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", addr, port)
|
||||
}
|
||||
|
||||
@@ -95,17 +95,6 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
||||
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnIdle(pubKey string) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.Close(true)
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnClose(pubKey string) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
@@ -114,7 +103,7 @@ func (s *Store) PeerConnClose(pubKey string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.Close(false)
|
||||
p.Close()
|
||||
}
|
||||
|
||||
func (s *Store) PeersPubKey() []string {
|
||||
|
||||
@@ -4,16 +4,18 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -21,7 +23,7 @@ import (
|
||||
|
||||
const (
|
||||
handlerTypeDynamic = iota
|
||||
handlerTypeDnsInterceptor
|
||||
handlerTypeDomain
|
||||
handlerTypeStatic
|
||||
)
|
||||
|
||||
@@ -552,16 +554,40 @@ func (w *Watcher) Stop() {
|
||||
w.currentChosenStatus = nil
|
||||
}
|
||||
|
||||
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||
switch handlerType(params.Route, params.UseNewDNSRoute) {
|
||||
case handlerTypeDnsInterceptor:
|
||||
return dnsinterceptor.New(params)
|
||||
func HandlerFromRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.WGIface,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
useNewDNSRoute bool,
|
||||
) RouteHandler {
|
||||
switch handlerType(rt, useNewDNSRoute) {
|
||||
case handlerTypeDomain:
|
||||
return dnsinterceptor.New(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
statusRecorder,
|
||||
dnsServer,
|
||||
peerStore,
|
||||
)
|
||||
case handlerTypeDynamic:
|
||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||
return dynamic.NewRoute(params, dnsAddr)
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
dnsRouterInteval,
|
||||
statusRecorder,
|
||||
wgInterface,
|
||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||
)
|
||||
default:
|
||||
return static.NewRoute(params)
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -570,8 +596,8 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
|
||||
return handlerTypeStatic
|
||||
}
|
||||
|
||||
if useNewDNSRoute {
|
||||
return handlerTypeDnsInterceptor
|
||||
if useNewDNSRoute && runtime.GOOS != "ios" {
|
||||
return handlerTypeDomain
|
||||
}
|
||||
return handlerTypeDynamic
|
||||
}
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
statuses map[route.ID]routerPeerStatus
|
||||
@@ -811,12 +811,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||
}
|
||||
|
||||
params := common.HandlerParams{
|
||||
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||
}
|
||||
// create new clientNetwork
|
||||
client := &Watcher{
|
||||
handler: static.NewRoute(params),
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type HandlerParams struct {
|
||||
Route *route.Route
|
||||
RouteRefCounter *refcounter.RouteRefCounter
|
||||
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
DnsRouterInterval time.Duration
|
||||
StatusRecorder *peer.Status
|
||||
WgInterface iface.WGIface
|
||||
DnsServer dns.Server
|
||||
PeerStore *peerstore.Store
|
||||
UseNewDNSRoute bool
|
||||
Firewall manager.Manager
|
||||
FakeIPManager *fakeip.Manager
|
||||
}
|
||||
@@ -4,23 +4,19 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -28,16 +24,6 @@ import (
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type internalDNATer interface {
|
||||
RemoveInternalDNATMapping(netip.Addr) error
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
type wgInterface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
@@ -47,24 +33,25 @@ type DnsInterceptor struct {
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedDomains domainMap
|
||||
wgInterface wgInterface
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
}
|
||||
|
||||
func New(params common.HandlerParams) *DnsInterceptor {
|
||||
func New(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
statusRecorder *peer.Status,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) *DnsInterceptor {
|
||||
return &DnsInterceptor{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
statusRecorder: params.StatusRecorder,
|
||||
dnsServer: params.DnsServer,
|
||||
wgInterface: params.WgInterface,
|
||||
peerStore: params.PeerStore,
|
||||
firewall: params.Firewall,
|
||||
fakeIPManager: params.FakeIPManager,
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
interceptedDomains: make(domainMap),
|
||||
peerStore: peerStore,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,19 +64,127 @@ func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// preResolveDomains performs background DNS resolution for non-wildcard domains
|
||||
func (d *DnsInterceptor) preResolveDomains() {
|
||||
for _, domain := range d.route.Domains {
|
||||
domainStr := string(domain)
|
||||
|
||||
if strings.HasPrefix(domainStr, "*.") {
|
||||
continue
|
||||
}
|
||||
|
||||
domainStr = strings.TrimSuffix(domainStr, ".")
|
||||
go func(domain string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := d.resolveAndUpdateDomain(ctx, domain); err != nil {
|
||||
log.Debugf("pre-resolve failed for domain %s: %v", domain, err)
|
||||
} else {
|
||||
log.Tracef("pre-resolve completed for domain %s", domain)
|
||||
}
|
||||
}(domainStr)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAndUpdateDomain performs DNS resolution and updates domain prefixes
|
||||
func (d *DnsInterceptor) resolveAndUpdateDomain(ctx context.Context, qDomain string) error {
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
return fmt.Errorf("no current peer key")
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get upstream IP: %v", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(qDomain), dns.TypeA)
|
||||
msg.Id = dns.Id()
|
||||
msg.MsgHdr.AuthenticatedData = true
|
||||
|
||||
reply, err := d.exchangeWithUpstream(ctx, msg, upstreamIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exchange with upstream: %v", err)
|
||||
}
|
||||
|
||||
if reply == nil || len(reply.Answer) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(dns.Fqdn(qDomain))
|
||||
return d.processResolveResponse(reply, resolvedDomain, resolvedDomain)
|
||||
}
|
||||
|
||||
// exchangeWithUpstream performs DNS exchange with the upstream server
|
||||
func (d *DnsInterceptor) exchangeWithUpstream(ctx context.Context, msg *dns.Msg, upstreamIP netip.Addr) (*dns.Msg, error) {
|
||||
client := &dns.Client{
|
||||
Timeout: nbdns.UpstreamTimeout,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
|
||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, msg, upstream)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// extractIPsFromDNSResponse extracts IP addresses from DNS answer records
|
||||
func (d *DnsInterceptor) extractIPsFromDNSResponse(reply *dns.Msg, domainForLogging domain.Domain) []netip.Prefix {
|
||||
if reply == nil || len(reply.Answer) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var prefixes []netip.Prefix
|
||||
for _, answer := range reply.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", domainForLogging, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", domainForLogging, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||
prefixes = append(prefixes, prefix)
|
||||
}
|
||||
|
||||
return prefixes
|
||||
}
|
||||
|
||||
// processResolveResponse extracts IPs from DNS response and updates domain prefixes
|
||||
func (d *DnsInterceptor) processResolveResponse(reply *dns.Msg, resolvedDomain, originalDomain domain.Domain) error {
|
||||
newPrefixes := d.extractIPsFromDNSResponse(reply, resolvedDomain)
|
||||
if len(newPrefixes) > 0 {
|
||||
return d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveRoute() error {
|
||||
d.mu.Lock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// Routes should use fake IPs
|
||||
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||
}
|
||||
|
||||
// AllowedIPs should use real IPs
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
@@ -97,10 +192,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
}
|
||||
}
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
}
|
||||
|
||||
d.cleanupDNATMappings()
|
||||
|
||||
for _, domain := range d.route.Domains {
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
@@ -113,68 +206,6 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
|
||||
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
|
||||
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
|
||||
return realPrefix
|
||||
}
|
||||
|
||||
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
|
||||
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
|
||||
}
|
||||
|
||||
return realPrefix
|
||||
}
|
||||
|
||||
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
|
||||
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
|
||||
// AllowedIPs always use real IPs
|
||||
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
|
||||
}
|
||||
|
||||
if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
realPrefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
|
||||
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
|
||||
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
|
||||
routePrefix := d.transformRealToFakePrefix(realPrefix)
|
||||
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
|
||||
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
|
||||
}
|
||||
|
||||
// Add to AllowedIPs if we have a current peer (uses real IPs)
|
||||
if d.currentPeerKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
|
||||
}
|
||||
|
||||
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
|
||||
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
|
||||
if d.currentPeerKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowedIPs use real IPs
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
|
||||
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
@@ -182,14 +213,20 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// AllowedIPs use real IPs
|
||||
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.currentPeerKey = peerKey
|
||||
go d.preResolveDomains()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -200,7 +237,6 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
var merr *multierror.Error
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// AllowedIPs use real IPs
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
@@ -213,18 +249,15 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
|
||||
// ServeDNS implements the dns.Handler interface
|
||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
requestID := nbdns.GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
// pass if non A/AAAA query
|
||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
|
||||
d.continueToNextHandler(w, r, "non A/AAAA query")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,19 +266,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
d.writeDNSError(w, r, logger, "no current peer key")
|
||||
d.writeDNSError(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
|
||||
if err != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -253,12 +280,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||
reply, err := d.exchangeWithUpstream(context.TODO(), r, upstreamIP)
|
||||
if err != nil {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -268,34 +294,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
answer = reply.Answer
|
||||
}
|
||||
|
||||
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS error response: %v", err)
|
||||
log.Errorf("failed to write DNS error response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
// Set Zero bit to signal handler chain to continue
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed writing DNS continue response: %v", err)
|
||||
log.Errorf("failed writing DNS continue response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,45 +345,13 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
||||
|
||||
// already punycode via RegisterHandler()
|
||||
originalDomain := domain.Domain(origPattern)
|
||||
if originalDomain == "" {
|
||||
originalDomain = resolvedDomain
|
||||
}
|
||||
|
||||
var newPrefixes []netip.Prefix
|
||||
for _, answer := range r.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||
if err := d.processResolveResponse(r, resolvedDomain, originalDomain); err != nil {
|
||||
log.Errorf("failed to process DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,22 +362,6 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logPrefixChanges handles the logging for prefix changes
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
@@ -392,163 +370,70 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
var merr *multierror.Error
|
||||
var dnatMappings map[netip.Addr]netip.Addr
|
||||
|
||||
// Handle DNAT mappings for new prefixes
|
||||
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
|
||||
dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
for _, prefix := range toAdd {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||
dnatMappings[fakeIP] = realIP
|
||||
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
} else {
|
||||
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add new prefixes
|
||||
for _, prefix := range toAdd {
|
||||
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if d.currentPeerKey == "" {
|
||||
continue
|
||||
}
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
resolvedDomain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
d.addDNATMappings(dnatMappings)
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
for _, prefix := range toRemove {
|
||||
// Routes use fake IPs
|
||||
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||
}
|
||||
// AllowedIPs use real IPs
|
||||
if err := d.removeAllowedIP(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.removeDNATMappings(toRemove)
|
||||
}
|
||||
|
||||
// Update domain prefixes using resolved domain as key - store real IPs
|
||||
// Update domain prefixes using resolved domain as key
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
if d.route.KeepRoute {
|
||||
// replace stored prefixes with old + added
|
||||
// nolint:gocritic
|
||||
newPrefixes = append(oldPrefixes, toAdd...)
|
||||
}
|
||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||
|
||||
// Store real IPs for status (user-facing), not fake IPs
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||
if len(realPrefixes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefix := range realPrefixes {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
} else {
|
||||
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalDnatFw checks if the firewall supports internal DNAT
|
||||
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||
if d.firewall == nil || runtime.GOOS != "android" {
|
||||
return nil, false
|
||||
}
|
||||
fw, ok := d.firewall.(internalDNATer)
|
||||
return fw, ok
|
||||
}
|
||||
|
||||
// addDNATMappings adds DNAT mappings to the firewall
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for fakeIP, realIP := range mappings {
|
||||
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
} else {
|
||||
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupDNATMappings removes all DNAT mappings for this interceptor
|
||||
func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
d.removeDNATMappings(prefixes)
|
||||
}
|
||||
}
|
||||
|
||||
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Replace A and AAAA records with fake IPs
|
||||
for _, answer := range reply.Answer {
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
realIP, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.A = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
|
||||
case *dns.AAAA:
|
||||
realIP, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.AAAA = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
@@ -53,16 +52,24 @@ type Route struct {
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
interval: params.DnsRouterInterval,
|
||||
statusRecorder: params.StatusRecorder,
|
||||
wgInterface: params.WgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
nextIP netip.Addr // Next IP to allocate
|
||||
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||
}
|
||||
|
||||
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||
func NewManager() *Manager {
|
||||
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||
|
||||
return &Manager{
|
||||
nextIP: baseIP,
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: baseIP,
|
||||
maxIP: maxIP,
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||
// Returns the fake IP, or existing fake IP if already allocated
|
||||
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||
if !realIP.Is4() {
|
||||
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if fakeIP, exists := m.allocated[realIP]; exists {
|
||||
return fakeIP, nil
|
||||
}
|
||||
|
||||
startIP := m.nextIP
|
||||
for {
|
||||
currentIP := m.nextIP
|
||||
|
||||
// Advance to next IP, wrapping at boundary
|
||||
if m.nextIP.Compare(m.maxIP) >= 0 {
|
||||
m.nextIP = m.baseIP
|
||||
} else {
|
||||
m.nextIP = m.nextIP.Next()
|
||||
}
|
||||
|
||||
// Check if current IP is available
|
||||
if _, inUse := m.fakeToReal[currentIP]; !inUse {
|
||||
m.allocated[realIP] = currentIP
|
||||
m.fakeToReal[currentIP] = realIP
|
||||
return currentIP, nil
|
||||
}
|
||||
|
||||
// Prevent infinite loop if all IPs exhausted
|
||||
if m.nextIP.Compare(startIP) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
fakeIP, exists := m.allocated[realIP]
|
||||
return fakeIP, exists
|
||||
}
|
||||
|
||||
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
realIP, exists := m.fakeToReal[fakeIP]
|
||||
return realIP, exists
|
||||
}
|
||||
|
||||
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
||||
return netip.MustParsePrefix("240.0.0.0/8")
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
if manager.baseIP.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||
}
|
||||
|
||||
if manager.maxIP.String() != "240.255.255.254" {
|
||||
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||
}
|
||||
|
||||
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||
t.Errorf("Expected nextIP to start at baseIP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("8.8.8.8")
|
||||
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
|
||||
if !fakeIP.Is4() {
|
||||
t.Error("Fake IP should be IPv4")
|
||||
}
|
||||
|
||||
// Check it's in the correct range
|
||||
if fakeIP.As4()[0] != 240 {
|
||||
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||
}
|
||||
|
||||
// Should return same fake IP for same real IP
|
||||
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get existing fake IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP.Compare(fakeIP2) != 0 {
|
||||
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||
|
||||
_, err := manager.AllocateFakeIP(realIPv6)
|
||||
if err == nil {
|
||||
t.Error("Expected error for IPv6 address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Should not exist initially
|
||||
_, exists := manager.GetFakeIP(realIP)
|
||||
if exists {
|
||||
t.Error("Fake IP should not exist before allocation")
|
||||
}
|
||||
|
||||
// Allocate and check
|
||||
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate: %v", err)
|
||||
}
|
||||
|
||||
fakeIP, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
t.Error("Fake IP should exist after allocation")
|
||||
}
|
||||
|
||||
if fakeIP.Compare(expectedFakeIP) != 0 {
|
||||
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleAllocations(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
allocations := make(map[netip.Addr]netip.Addr)
|
||||
|
||||
// Allocate multiple IPs
|
||||
for i := 1; i <= 100; i++ {
|
||||
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
for _, existingFake := range allocations {
|
||||
if fakeIP.Compare(existingFake) == 0 {
|
||||
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
allocations[realIP] = fakeIP
|
||||
}
|
||||
|
||||
// Verify all allocations can be retrieved
|
||||
for realIP, expectedFake := range allocations {
|
||||
actualFake, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
t.Errorf("Missing allocation for %s", realIP.String())
|
||||
}
|
||||
if actualFake.Compare(expectedFake) != 0 {
|
||||
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFakeIPBlock(t *testing.T) {
|
||||
manager := NewManager()
|
||||
block := manager.GetFakeIPBlock()
|
||||
|
||||
expected := "240.0.0.0/8"
|
||||
if block.String() != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
const numGoroutines = 50
|
||||
const allocationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||
|
||||
// Concurrent allocations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < allocationsPerGoroutine; j++ {
|
||||
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
|
||||
return
|
||||
}
|
||||
results <- fakeIP
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Check for duplicates
|
||||
seen := make(map[netip.Addr]bool)
|
||||
count := 0
|
||||
for fakeIP := range results {
|
||||
if seen[fakeIP] {
|
||||
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
|
||||
}
|
||||
seen[fakeIP] = true
|
||||
count++
|
||||
}
|
||||
|
||||
if count != numGoroutines*allocationsPerGoroutine {
|
||||
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPExhaustion(t *testing.T) {
|
||||
// Create a manager with limited range for testing
|
||||
manager := &Manager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||
}
|
||||
|
||||
// Allocate all available IPs
|
||||
realIPs := []netip.Addr{
|
||||
netip.MustParseAddr("1.0.0.1"),
|
||||
netip.MustParseAddr("1.0.0.2"),
|
||||
netip.MustParseAddr("1.0.0.3"),
|
||||
}
|
||||
|
||||
for _, realIP := range realIPs {
|
||||
_, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to allocate one more - should fail
|
||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||
if err == nil {
|
||||
t.Error("Expected exhaustion error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapAround(t *testing.T) {
|
||||
// Create manager starting near the end of range
|
||||
manager := &Manager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
}
|
||||
|
||||
// Allocate the last IP
|
||||
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP1.String() != "240.0.0.254" {
|
||||
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||
}
|
||||
|
||||
// Next allocation should wrap around to the beginning
|
||||
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP2.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -26,8 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
@@ -53,7 +49,7 @@ type Manager interface {
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
SetFirewall(firewall.Manager) error
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@@ -67,7 +63,6 @@ type ManagerConfig struct {
|
||||
InitialRoutes []*route.Route
|
||||
StateManager *statemanager.Manager
|
||||
DNSServer dns.Server
|
||||
DNSFeatureFlag bool
|
||||
PeerStore *peerstore.Store
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -94,13 +89,11 @@ type DefaultManager struct {
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
dnsServer dns.Server
|
||||
firewall firewall.Manager
|
||||
peerStore *peerstore.Store
|
||||
useNewDNSRoute bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
fakeIPManager *fakeip.Manager
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
@@ -136,31 +129,11 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
dm.setupAndroidRoutes(config)
|
||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||
dm.notifier.SetInitialClientRoutes(cr)
|
||||
}
|
||||
return dm
|
||||
}
|
||||
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
cr := m.initialClientRoutes(config.InitialRoutes)
|
||||
|
||||
routesForComparison := slices.Clone(cr)
|
||||
|
||||
if config.DNSFeatureFlag {
|
||||
m.fakeIPManager = fakeip.NewManager()
|
||||
|
||||
id := uuid.NewString()
|
||||
fakeIPRoute := &route.Route{
|
||||
ID: route.ID(id),
|
||||
Network: m.fakeIPManager.GetFakeIPBlock(),
|
||||
NetID: route.NetID(id),
|
||||
Peer: m.pubKey,
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
cr = append(cr, fakeIPRoute)
|
||||
}
|
||||
|
||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||
m.routeRefCounter = refcounter.New(
|
||||
@@ -249,16 +222,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall manager for the DefaultManager
|
||||
// Not thread-safe, should be called before starting the manager
|
||||
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
m.firewall = firewall
|
||||
|
||||
if m.disableServerRoutes || firewall == nil {
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
if m.disableServerRoutes {
|
||||
log.Info("server routes are disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if firewall == nil {
|
||||
return errors.New("firewall manager is not set")
|
||||
}
|
||||
|
||||
var err error
|
||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
if err != nil {
|
||||
@@ -326,20 +299,17 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
}
|
||||
|
||||
for id, route := range toAdd {
|
||||
params := common.HandlerParams{
|
||||
Route: route,
|
||||
RouteRefCounter: m.routeRefCounter,
|
||||
AllowedIPsRefCounter: m.allowedIPsRefCounter,
|
||||
DnsRouterInterval: m.dnsRouteInterval,
|
||||
StatusRecorder: m.statusRecorder,
|
||||
WgInterface: m.wgInterface,
|
||||
DnsServer: m.dnsServer,
|
||||
PeerStore: m.peerStore,
|
||||
UseNewDNSRoute: m.useNewDNSRoute,
|
||||
Firewall: m.firewall,
|
||||
FakeIPManager: m.fakeIPManager,
|
||||
}
|
||||
handler := client.HandlerFromRoute(params)
|
||||
handler := client.HandlerFromRoute(
|
||||
route,
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsRouteInterval,
|
||||
m.statusRecorder,
|
||||
m.wgInterface,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
m.useNewDNSRoute,
|
||||
)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||
continue
|
||||
@@ -547,7 +517,6 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
@@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
||||
|
||||
}
|
||||
|
||||
func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
|
||||
136
client/internal/routemanager/notifier/notifier.go
Normal file
136
client/internal/routemanager/notifier/notifier.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
initialRouteRanges []string
|
||||
routeRanges []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
// filter out domain routes
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
nets = append(nets, r.Network.String())
|
||||
}
|
||||
sort.Strings(nets)
|
||||
n.initialRouteRanges = nets
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
if runtime.GOOS != "android" {
|
||||
return
|
||||
}
|
||||
newNets := make([]string, 0)
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
newNets = append(newNets, r.Network.String())
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *Notifier) hasDiff(a []string, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
||||
}
|
||||
|
||||
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
|
||||
func addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||
ranges := inputRanges
|
||||
for _, r := range inputRanges {
|
||||
// we are intentionally adding the ipv6 default range in case of ipv4 default range
|
||||
// to ensure that all traffic is managed by the tunnel interface on android
|
||||
if r == "0.0.0.0/0" {
|
||||
ranges = append(ranges, "::/0")
|
||||
break
|
||||
}
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||
// initialRoutes contains fake IP block for interface configuration
|
||||
filteredInitial := make([]*route.Route, 0)
|
||||
for _, r := range initialRoutes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
filteredInitial = append(filteredInitial, r)
|
||||
}
|
||||
n.initialRoutes = filteredInitial
|
||||
|
||||
// routesForComparison excludes fake IP block for comparison with new routes
|
||||
filteredComparison := make([]*route.Route, 0)
|
||||
for _, r := range routesForComparison {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
filteredComparison = append(filteredComparison, r)
|
||||
}
|
||||
n.currentRoutes = filteredComparison
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
var newRoutes []*route.Route
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
newRoutes = append(newRoutes, r)
|
||||
}
|
||||
}
|
||||
|
||||
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
n.currentRoutes = newRoutes
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
|
||||
// Not used on Android
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
routeStrings := n.routesToStrings(n.currentRoutes)
|
||||
sort.Strings(routeStrings)
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
||||
nets := make([]string, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
nets = append(nets, r.NetString())
|
||||
}
|
||||
return nets
|
||||
}
|
||||
|
||||
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
|
||||
slices.SortFunc(a, func(x, y *route.Route) int {
|
||||
return strings.Compare(x.NetString(), y.NetString())
|
||||
})
|
||||
slices.SortFunc(b, func(x, y *route.Route) int {
|
||||
return strings.Compare(x.NetString(), y.NetString())
|
||||
})
|
||||
|
||||
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
|
||||
return x.NetString() == y.NetString()
|
||||
})
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
initialStrings := n.routesToStrings(n.initialRoutes)
|
||||
sort.Strings(initialStrings)
|
||||
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
|
||||
}
|
||||
|
||||
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
|
||||
for _, r := range routes {
|
||||
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
|
||||
return append(slices.Clone(inputRanges), "::/0")
|
||||
}
|
||||
}
|
||||
return inputRanges
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
currentPrefixes []string
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// iOS doesn't care about initial routes
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
|
||||
if slices.Equal(n.currentPrefixes, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||
for _, r := range inputRanges {
|
||||
if r == "0.0.0.0/0" {
|
||||
return append(slices.Clone(inputRanges), "::/0")
|
||||
}
|
||||
}
|
||||
return inputRanges
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Notifier struct{}
|
||||
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return []string{}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -17,11 +16,11 @@ type Route struct {
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(params common.HandlerParams) *Route {
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
return &Route{
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
@@ -29,10 +28,7 @@ func (n Nexthop) String() string {
|
||||
if n.Intf == nil {
|
||||
return n.IP.String()
|
||||
}
|
||||
if n.IP.IsValid() {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
@@ -53,9 +49,6 @@ type SysOps struct {
|
||||
mu sync.Mutex
|
||||
// notifier is used to notify the system of route changes (also used on mobile)
|
||||
notifier *notifier.Notifier
|
||||
// seq is an atomic counter for generating unique sequence numbers for route messages
|
||||
//nolint:unused // only used on BSD systems
|
||||
seq atomic.Uint32
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
@@ -65,11 +58,6 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unused // only used on BSD systems
|
||||
func (r *SysOps) getSeq() int {
|
||||
return int(r.seq.Add(1))
|
||||
}
|
||||
|
||||
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
|
||||
@@ -10,11 +10,10 @@ type StatusType string
|
||||
const (
|
||||
StatusIdle StatusType = "Idle"
|
||||
|
||||
StatusConnecting StatusType = "Connecting"
|
||||
StatusConnected StatusType = "Connected"
|
||||
StatusNeedsLogin StatusType = "NeedsLogin"
|
||||
StatusLoginFailed StatusType = "LoginFailed"
|
||||
StatusSessionExpired StatusType = "SessionExpired"
|
||||
StatusConnecting StatusType = "Connecting"
|
||||
StatusConnected StatusType = "Connected"
|
||||
StatusNeedsLogin StatusType = "NeedsLogin"
|
||||
StatusLoginFailed StatusType = "LoginFailed"
|
||||
)
|
||||
|
||||
// CtxInitState setup context state into the context tree.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user