Compare commits

..

14 Commits

Author SHA1 Message Date
crn4
ca432ff681 refactor components creation 2026-01-08 17:34:53 +01:00
crn4
7b5d7aeb2e refactor 2026-01-08 17:19:16 +01:00
crn4
3bdce8d0b6 added support for ssh auth users to components 2026-01-08 16:36:40 +01:00
crn4
d534ce9dfc minor changes to benchmarks 2026-01-08 15:06:02 +01:00
crn4
bbc2b42807 changes after main merge 2026-01-08 13:44:15 +01:00
crn4
db9cc52c96 conflicts resolution 2026-01-08 13:25:10 +01:00
crn4
3209b241d9 minor opts 2026-01-06 12:03:01 +01:00
crn4
7566afd7d0 components approach - we are sending all components needed for nmap assembling on client side 2025-12-29 17:31:16 +01:00
crn4
e93d4132d3 Merge branch 'main' into nmap/compaction 2025-12-18 16:50:41 +01:00
crn4
21e5e6ddff cached compaction 2025-12-05 00:29:58 +01:00
crn4
10fb18736b benchmark on both maps was added 2025-12-05 00:29:58 +01:00
crn4
942abeca0c select nmap based on peer version 2025-12-05 00:29:58 +01:00
crn4
e184a43e8a firewallrules compacted 2025-12-05 00:29:58 +01:00
crn4
f33f84299f routes compacted 2025-12-05 00:29:58 +01:00
164 changed files with 4343 additions and 4809 deletions

View File

@@ -1,15 +1,15 @@
FROM golang:1.25-bookworm
FROM golang:1.23-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\
gettext-base=0.21-12 \
iptables=1.8.9-2 \
libgl1-mesa-dev=22.3.6-1+deb12u1 \
xorg-dev=1:7.7+23 \
libayatana-appindicator3-dev=0.5.92-1 \
gettext-base=0.21-4 \
iptables=1.8.7-1 \
libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@latest
&& go install -v golang.org/x/tools/gopls@v0.18.1
WORKDIR /app

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.25-alpine \
golang:1.24-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -259,7 +259,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
-timeout 10m ./relay/... ./shared/relay/...
test_signal:
name: "Signal / Unit"

View File

@@ -52,10 +52,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@v4
with:
version: latest
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
args: --timeout=12m
args: --timeout=12m --out-format colored-line-number

View File

@@ -63,7 +63,7 @@ jobs:
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"

View File

@@ -14,9 +14,6 @@ jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
env:
GOOS: js
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -27,14 +24,16 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
working-directory: ./client
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:

View File

@@ -1,124 +1,139 @@
version: "2"
linters:
default: none
enable:
- bodyclose
- dupword
- durationcheck
- errcheck
- forbidigo
- gocritic
- gosec
- govet
- ineffassign
- mirror
- misspell
- nilerr
- nilnil
- predeclared
- revive
- sqlclosecheck
- staticcheck
- unused
- wastedassign
settings:
errcheck:
check-type-assertions: false
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
gosec:
includes:
- G101
- G103
- G104
- G106
- G108
- G109
- G110
- G111
- G201
- G202
- G203
- G301
- G302
- G303
- G304
- G305
- G306
- G307
- G403
- G502
- G503
- G504
- G601
- G602
govet:
enable:
- nilness
enable-all: false
revive:
rules:
- name: exported
arguments:
- checkPrivateReceivers
- sayRepetitiveInsteadOfStutters
severity: warning
disabled: false
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
gosec:
includes:
- G101 # Look for hard coded credentials
#- G102 # Bind to all interfaces
- G103 # Audit the use of unsafe block
- G104 # Audit errors not checked
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
#- G107 # Url provided to HTTP request as taint input
- G108 # Profiling endpoint automatically exposed on /debug/pprof
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
- G110 # Potential DoS vulnerability via decompression bomb
- G111 # Potential directory traversal
#- G112 # Potential slowloris attack
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
#- G114 # Use of net/http serve function that has no support for setting timeouts
- G201 # SQL query construction using format string
- G202 # SQL query construction using string concatenation
- G203 # Use of unescaped data in HTML templates
#- G204 # Audit use of command execution
- G301 # Poor file permissions used when creating a directory
- G302 # Poor file permissions used with chmod
- G303 # Creating tempfile using a predictable path
- G304 # File path provided as taint input
- G305 # File traversal when extracting zip/tar archive
- G306 # Poor file permissions used when writing to a new file
- G307 # Poor file permissions used when creating a file with os.Create
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
#- G402 # Look for bad TLS connection settings
- G403 # Ensure minimum RSA key length of 2048 bits
#- G404 # Insecure random number source (rand)
#- G501 # Import blocklist: crypto/md5
- G502 # Import blocklist: crypto/des
- G503 # Import blocklist: crypto/rc4
- G504 # Import blocklist: net/http/cgi
#- G505 # Import blocklist: crypto/sha1
- G601 # Implicit memory aliasing of items from a range statement
- G602 # Slice access out of bounds
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
revive:
rules:
- linters:
- forbidigo
path: management/cmd/root\.go
- linters:
- forbidigo
path: signal/cmd/root\.go
- linters:
- unused
path: sharedsock/filter\.go
- linters:
- unused
path: client/firewall/iptables/rule\.go
- linters:
- gosec
- mirror
path: test\.go
- linters:
- nilnil
path: mock\.go
- linters:
- staticcheck
text: grpc.DialContext is deprecated
- linters:
- staticcheck
text: grpc.WithBlock is deprecated
- linters:
- staticcheck
text: "QF1001"
- linters:
- staticcheck
text: "QF1008"
- linters:
- staticcheck
text: "QF1012"
paths:
- third_party$
- builtin$
- examples$
- name: exported
severity: warning
disabled: false
arguments:
- "checkPrivateReceivers"
- "sayRepetitiveInsteadOfStutters"
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- dupword # dupword checks for duplicate words in the source code
- durationcheck # durationcheck checks for two durations multiplied together
- forbidigo # forbidigo forbids identifiers
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gosec # inspects source code for security problems
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
- misspell # misspess finds commonly misspelled English words in comments
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$
exclude-rules:
# allow fmt
- path: management/cmd/root\.go
linters: forbidigo
- path: signal/cmd/root\.go
linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -38,11 +38,6 @@
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>

View File

@@ -136,7 +136,6 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
//nolint
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
@@ -314,8 +313,9 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
profName = activeProf.Name
}
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
statusOutputString = overview.FullDetailSummary()
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString
}

View File

@@ -81,7 +81,6 @@ var loginCmd = &cobra.Command{
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
@@ -207,7 +206,6 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -1,4 +1,5 @@
//go:build pprof
// +build pprof
package cmd

View File

@@ -390,7 +390,6 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
var statusOutputString string
switch {
case detailFlag:
statusOutputString = outputInformationHolder.FullDetailSummary()
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
case jsonFlag:
statusOutputString, err = outputInformationHolder.JSON()
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
case yamlFlag:
statusOutputString, err = outputInformationHolder.YAML()
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
}
if err != nil {
@@ -124,7 +124,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -89,6 +89,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)

View File

@@ -216,7 +216,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -10,7 +10,6 @@ import (
"net/netip"
"os"
"sync"
"time"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
@@ -21,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
@@ -31,11 +29,6 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
const (
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
)
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -45,7 +38,6 @@ type Client struct {
setupKey string
jwtToken string
connect *internal.ConnectClient
recorder *peer.Status
}
// Options configures a new Client.
@@ -169,17 +161,11 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect != nil {
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
@@ -187,9 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
@@ -213,7 +197,6 @@ func (c *Client) Start(startCtx context.Context) error {
}
c.connect = client
c.cancel = cancel
return nil
}
@@ -228,23 +211,17 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted
}
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1)
connect := c.connect
go func() {
done <- connect.Stop()
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.connect = nil
c.cancel = nil
return ctx.Err()
case err := <-done:
c.connect = nil
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
@@ -264,40 +241,18 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
// With lazy connections, the connection is established on first traffic.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
// Check context status upfront
if ctx.Err() != nil {
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
return nil, ctx.Err()
}
engine, err := c.getEngine()
if err != nil {
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
return nil, err
}
nsnet, err := engine.GetNet()
if err != nil {
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
return nil, fmt.Errorf("get net: %w", err)
}
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when DialContext is called. The netstack
// dial triggers WireGuard traffic which activates the lazy connection.
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
conn, err := nsnet.DialContext(ctx, network, address)
if err != nil {
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
return nil, err
}
logrus.Infof("embed.Dial: successfully connected to %s", address)
return conn, nil
return nsnet.DialContext(ctx, network, address)
}
// DialContext dials a network address in the netbird network with context
@@ -360,90 +315,6 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// WaitForPeerConnection waits for a peer with the given IP to be connected.
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
logrus.Infof("Waiting for peer %s to be connected", peerIP)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
case <-ticker.C:
status, err := c.Status()
if err != nil {
logrus.Debugf("Error getting status while waiting for peer: %v", err)
continue
}
for _, p := range status.Peers {
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
return nil
}
}
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
// Note: ConnectClient doesn't have SetLogLevel method
_ = connect
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -386,8 +386,11 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if ip.IsUnspecified() {
matchByIP = false
}
if matchByIP {
if ipsetName != "" {

View File

@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex = -1, -1
var denyRuleIndex, acceptRuleIndex int = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)

View File

@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex = -1, -1
var acceptRuleIndex, denyRuleIndex int = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
@@ -208,13 +208,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName {
case "accept-http":
if lookup.SetName == "accept-http" {
hasAcceptHTTPSet = true
case "deny-http":
} else if lookup.SetName == "deny-http" {
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
@@ -224,10 +222,9 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind {
case expr.VerdictAccept:
if verdict.Kind == expr.VerdictAccept {
action = "ACCEPT"
case expr.VerdictDrop:
} else if verdict.Kind == expr.VerdictDrop {
action = "DROP"
}
}

View File

@@ -29,7 +29,7 @@ import (
)
const (
layerTypeAll = 255
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
@@ -262,7 +262,10 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix := iface.Address().Network
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering(
@@ -436,7 +439,19 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort
r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer)
switch proto {
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet
@@ -481,17 +496,16 @@ func (m *Manager) addRouteFiltering(
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort,
dstPort: dPort,
action: action,
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
@@ -781,7 +795,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum = pseudoSum
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
@@ -931,7 +945,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked {
pnum := getProtocolFromPacket(d)
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -996,22 +1010,20 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false
}
protoLayer := d.decoded[1]
proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: proto,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
@@ -1040,33 +1052,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return nftypes.TCP
return firewall.ProtocolTCP, nftypes.TCP
case layers.LayerTypeUDP:
return nftypes.UDP
return firewall.ProtocolUDP, nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return nftypes.ICMP
return firewall.ProtocolICMP, nftypes.ICMP
default:
return nftypes.ProtocolUnknown
return firewall.ProtocolALL, nftypes.ProtocolUnknown
}
}
@@ -1238,30 +1233,19 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
return nil, false
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
@@ -1280,8 +1264,21 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
break
}
}
if !sourceMatched {
return false
}
return sourceMatched
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched

View File

@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}
}

View File

@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)
})
}
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
}
})
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -2,7 +2,6 @@ package forwarder
import (
"fmt"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -17,7 +16,7 @@ type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu atomic.Uint32
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -29,7 +28,7 @@ func (e *endpoint) IsAttached() bool {
}
func (e *endpoint) MTU() uint32 {
return e.mtu.Load()
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -83,22 +82,6 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
func (e *endpoint) Close() {
// Endpoint cleanup - nothing to do as device is managed externally
}
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
// Link address is not used for this endpoint type
}
func (e *endpoint) SetMTU(mtu uint32) {
e.mtu.Store(mtu)
}
func (e *endpoint) SetOnCloseAction(func()) {
// No action needed on close
}
type epID stack.TransportEndpointID
func (i epID) String() string {

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@@ -36,16 +35,14 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -63,8 +60,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
endpoint.mtu.Store(uint32(mtu))
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("create NIC: %v", err)
@@ -106,16 +103,15 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
}
receiveWindow := defaultReceiveWindow
@@ -133,8 +129,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
f.checkICMPCapability()
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
@@ -204,24 +198,3 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
DstPort: dstPort,
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}

View File

@@ -2,11 +2,8 @@ package forwarder
import (
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"runtime"
"time"
"github.com/google/uuid"
@@ -17,95 +14,30 @@ import (
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
// For Echo Requests, send and wait for response
if icmpHdr.Type() == header.ICMPv4Echo {
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
if !f.hasRawICMPAccess {
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return true
}
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err)
}
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
}
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
return conn, nil
}
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
@@ -113,22 +45,38 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
}
}()
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
response := make([]byte, f.endpoint.mtu.Load())
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
@@ -137,7 +85,31 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
return 0
}
return f.injectICMPReply(id, response[:n])
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
}
// sendICMPEvent stores flow events for ICMP packets
@@ -180,95 +152,3 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
f.flowLogger.StoreEvent(fields)
}
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
// This is used as a fallback when raw socket access is not available.
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
// Returns the size of the injected packet.
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyICMPHdr := header.ICMPv4(replyICMP)
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
return f.injectICMPReply(id, replyICMP)
}
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
// Returns the total size of the injected packet, or 0 if injection failed.
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
return 0
}
return len(fullPacket)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
@@ -132,10 +131,10 @@ func (f *udpForwarder) cleanup() {
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return false
return
}
id := r.ID()
@@ -145,7 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return true
return
}
flowID := uuid.New()
@@ -163,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return false
return
}
// Create wait queue for blocking syscalls
@@ -174,10 +173,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return false
return
}
inConn := gonet.NewUDPConn(&wq, ep)
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
@@ -200,7 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return true
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
@@ -209,7 +208,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
return true
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
@@ -350,7 +348,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {

View File

@@ -130,7 +130,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}

View File

@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
_, _ = mapManager.localIPs[ip.String()]
}
})
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
_, _ = mapManager.localIPs[ip.String()]
}
})
}

View File

@@ -168,15 +168,6 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
}
}
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {

View File

@@ -234,10 +234,9 @@ func TestInboundPortDNATNegative(t *testing.T) {
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
switch tc.protocol {
case layers.IPProtocolTCP:
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
case layers.IPProtocolUDP:
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix
dstSet firewall.Set
destinations []netip.Prefix
protoLayer gopacket.LayerType
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
protoLayer := d.decoded[1]
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {

View File

@@ -27,23 +27,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:

View File

@@ -1,3 +1,6 @@
//go:build ios
// +build ios
package device
import (

View File

@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
}
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
log.Debugf("dialing %s %s", network, addr)
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Warnf("NSDialer.Dial failed: %s", err)
log.Debugf("failed to deal connection: %s", err)
}
return conn, err
}

View File

@@ -420,19 +420,6 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil
}
// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}
fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {

View File

@@ -507,13 +507,15 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
if p.Base == expr.PayloadBaseNetworkHeader {
switch p.Offset {
case 12:
switch p.Len {
case 4, 2:
if p.Len == 4 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
case 16:
switch p.Len {
case 4, 2:
if p.Len == 4 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.NonAuthoritative {
if zone.SkipPTRProcess {
continue
}
for _, record := range zone.Records {

View File

@@ -3,21 +3,17 @@ package dns
import (
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
)
const (
PriorityMgmtCache = 150
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100
@@ -47,23 +43,7 @@ type HandlerChain struct {
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
requestID string
shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -72,7 +52,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true
return nil
}
w.response = m
return w.ResponseWriter.WriteMsg(m)
}
@@ -122,8 +101,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
c.logHandlers()
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -163,109 +140,68 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break
}
}
}
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
qname := strings.ToLower(r.Question[0].Name)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
if !c.isHandlerMatch(qname, entry) {
continue
}
matched := c.isHandlerMatch(qname, entry)
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
continue
}
entry.Handler.ServeDNS(chainWriter, r)
c.logResponse(logger, chainWriter, qname, startTime)
return
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
}
}
// No handler matched or all handlers passed
logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
log.Tracef("no handler found for domain=%s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":

View File

@@ -1,52 +1,30 @@
package local
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
cancel context.CancelFunc
}
func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
}
@@ -59,18 +37,7 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
@@ -81,85 +48,35 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
if len(r.Question) == 0 {
logger.Debug("received local resolver request with no question")
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
}
if err := w.WriteMsg(replyMessage); err != nil {
logger.Warnf("failed to write the local resolver response: %v", err)
}
}
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
log.Warnf("failed to write the local resolver response: %v", err)
}
}
@@ -172,27 +89,8 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
return exists
}
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
@@ -200,14 +98,10 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
cnameQuestion := dns.Question{
Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return lookupResult{rcode: dns.RcodeNameError}
return nil
}
recordsCopy := slices.Clone(records)
@@ -225,178 +119,20 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
d.mu.Unlock()
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
return recordsCopy
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName)) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}

View File

@@ -1,14 +1,8 @@
package local
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -18,18 +12,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
// mockResolver implements resolver for testing
type mockResolver struct {
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
}
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
if m.lookupFunc != nil {
return m.lookupFunc(ctx, network, host)
}
return nil, nil
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -124,11 +106,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
resolver := NewResolver()
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(zone1)
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
@@ -140,7 +122,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(zone2)
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
@@ -169,10 +151,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(zones)
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
@@ -213,10 +195,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(zones)
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
@@ -282,7 +264,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
}
// Update resolver with the records
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
@@ -397,7 +379,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
}
// Update resolver with both records
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
@@ -494,20 +476,6 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
// with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver()
// Mock external resolver for CNAME target resolution
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "target.example.com." {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
if network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.",
@@ -525,7 +493,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
RData: "target.example.com.",
}
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
testCases := []struct {
name string
@@ -614,808 +582,3 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
})
}
}
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("simple internal CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "target.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "192.168.1.1", a.A.String())
})
t.Run("multi-hop CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3)
})
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok)
})
}
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
t.Run("chain at max depth resolves", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 7 CNAMEs (under max of 8)
for i := 1; i <= 7; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("hop%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("hop%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 8)
})
t.Run("chain exceeding max depth stops", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 10 CNAMEs (exceeds max of 8)
for i := 1; i <= 10; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("deep%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("deep%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
// Should NOT have the final A record (chain too deep)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
},
}})
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
}
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "93.184.216.34", a.A.String())
})
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
aaaa, ok := resp.Answer[1].(*dns.AAAA)
require.True(t, ok)
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
})
t.Run("concurrent external resolution", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
var wg sync.WaitGroup
results := make([]*dns.Msg, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
results[idx] = resp
}(i)
}
wg.Wait()
for i, resp := range results {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
}
})
}
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update sets zones correctly", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}},
{Domain: "test.local."},
})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("other.example.com."))
assert.True(t, resolver.isInManagedZone("sub.test.local."))
assert.False(t, resolver.isInManagedZone("external.com."))
})
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
})
t.Run("Update clears zones", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
resolver.Update(nil)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
}
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
})
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.other.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
})
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
resolver := NewResolver()
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." {
if network == "ip6" {
// No AAAA records
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
if network == "ip4" {
// But A records exist - domain exists
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
// Table-driven test for all external resolution outcomes
externalCases := []struct {
name string
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
expectedRcode int
expectedAnswer int
}{
{
name: "external NXDOMAIN (both A and AAAA not found)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeNameError,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (temporary error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTemporary: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (timeout)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTimeout: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (generic error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, fmt.Errorf("connection refused")
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external success with IPs",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeSuccess,
expectedAnswer: 2, // CNAME + A
},
}
for _, tc := range externalCases {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
if tc.expectedAnswer > 0 {
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "first answer should be CNAME")
}
})
}
}
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
// trigger fallthrough (Zero bit set) when no records match
func TestLocalResolver_Fallthrough(t *testing.T) {
resolver := NewResolver()
record := nbdns.SimpleRecord{
Name: "existing.custom.zone.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}
testCases := []struct {
name string
zones []nbdns.CustomZone
queryName string
expectFallthrough bool
expectRecord bool
}{
{
name: "Authoritative zone returns NXDOMAIN without fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: false,
expectRecord: false,
},
{
name: "Non-authoritative zone triggers fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: true,
expectRecord: false,
},
{
name: "Record found in non-authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
{
name: "Record found in authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resolver.Update(tc.zones)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response")
if tc.expectFallthrough {
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
} else {
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
}
if tc.expectRecord {
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
}
})
}
}
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
t.Run("direct record lookup is authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.True(t, resp.Authoritative)
})
t.Run("external resolution is not authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2)
assert.False(t, resp.Authoritative)
})
}
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Len(t, resp.Answer, 0)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
resolver.Stop()
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})
lookupCtxCanceled := make(chan struct{})
resolver.resolver = &mockResolver{
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
close(lookupStarted)
<-ctx.Done()
close(lookupCtxCanceled)
return nil, ctx.Err()
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
done := make(chan struct{})
go func() {
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
close(done)
}()
<-lookupStarted
resolver.Stop()
select {
case <-lookupCtxCanceled:
case <-time.After(time.Second):
t.Fatal("external lookup context was not canceled")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("ServeDNS did not return after Stop")
}
})
}
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "EXAMPLE.COM.",
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
NonAuthoritative: true,
}})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG)
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
}
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
func BenchmarkFindZone_BestCase(b *testing.B) {
resolver := NewResolver()
// Single zone that matches immediately
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
NonAuthoritative: true,
}})
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough("example.com.")
}
}
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
func BenchmarkFindZone_WorstCase(b *testing.B) {
resolver := NewResolver()
// 100 zones that won't match
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
NonAuthoritative: true,
})
}
resolver.Update(zones)
// Query with many labels that won't match any zone
qname := "a.b.c.d.e.f.g.h.external.com."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
func BenchmarkFindZone_TypicalCase(b *testing.B) {
resolver := NewResolver()
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
resolver.Update([]nbdns.CustomZone{
{Domain: "netbird.cloud.", NonAuthoritative: false},
{Domain: "custom.local.", NonAuthoritative: true},
})
// Query for subdomain of user zone
qname := "myhost.custom.local."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver := NewResolver()
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
})
}
resolver.Update(zones)
// Query that matches zone50
qname := "host.zone50.internal."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.isInManagedZone(qname)
}
}

View File

@@ -1,197 +0,0 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -498,7 +498,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
s.localResolver.Update(localZones)
// register local records
s.localResolver.Update(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -631,7 +632,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -656,9 +659,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
var muxUpdates []handlerWrapper
var zones []nbdns.CustomZone
var localRecords []nbdns.SimpleRecord
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -672,20 +675,17 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, zones, nil
return muxUpdates, localRecords, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
@@ -741,7 +741,9 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
s.hostsDNSHolder,
domainGroup.domain,
@@ -922,7 +924,9 @@ func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,

View File

@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -82,10 +81,6 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil
}
func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -133,7 +128,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalZones []nbdns.CustomZone
initLocalRecords []nbdns.SimpleRecord
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -185,8 +180,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -226,19 +221,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -256,11 +251,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -278,11 +273,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -304,8 +299,8 @@ func TestUpdateDNSServer(t *testing.T) {
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -320,8 +315,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -390,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -515,7 +510,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -2052,7 +2048,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -2,6 +2,7 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -18,10 +19,8 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -114,7 +113,10 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
@@ -200,18 +202,11 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
@@ -419,56 +414,16 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
return nil, err
log.Errorf("failed to generate request ID: %v", err)
return ""
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
return reply, nil
return hex.EncodeToString(bytes)
}
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}
reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -23,7 +23,9 @@ type upstreamResolver struct {
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ WGIface,
_ string,
_ netip.Addr,
_ netip.Prefix,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

View File

@@ -5,23 +5,22 @@ package dns
import (
"context"
"net/netip"
"runtime"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolver struct {
*upstreamResolverBase
nsNet *netstack.Net
}
func newUpstreamResolver(
ctx context.Context,
wgIface WGIface,
_ string,
_ netip.Addr,
_ netip.Prefix,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -29,23 +28,12 @@ func newUpstreamResolver(
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
// Similar to iOS logic, we should determine if the DNS server is reachable directly
// or needs to go through the tunnel, and only use netstack when necessary.
// For now, only use netstack on JS platform where direct access is not possible.
if u.nsNet != nil && runtime.GOOS == "js" {
start := time.Now()
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
return reply, time.Since(start), err
}
client := &dns.Client{
Timeout: ClientTimeout,
}

View File

@@ -26,7 +26,9 @@ type upstreamResolverIOS struct {
func newUpstreamResolver(
ctx context.Context,
wgIface WGIface,
interfaceName string,
ip netip.Addr,
net netip.Prefix,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -35,9 +37,9 @@ func newUpstreamResolver(
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: wgIface.Address().IP,
lNet: wgIface.Address().Network,
interfaceName: wgIface.Name(),
lIP: ip,
lNet: net,
interfaceName: interfaceName,
}
ios.upstreamClient = ios

View File

@@ -2,17 +2,13 @@ package dns
import (
"context"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
@@ -62,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
// Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
@@ -116,19 +112,6 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
}
type mockNetstackProvider struct{}
func (m *mockNetstackProvider) Name() string { return "mock" }
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration

View File

@@ -5,8 +5,6 @@ package dns
import (
"net"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -19,5 +17,4 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -1,8 +1,6 @@
package dns
import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -14,6 +12,5 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error)
}

View File

@@ -18,7 +18,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
@@ -190,22 +189,29 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return nil
}
question := query.Question[0]
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
@@ -215,35 +221,33 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
return nil
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
}
@@ -261,33 +265,19 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
log.Errorf("failed to write DNS response: %v", err)
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
log.Errorf("failed to write DNS response: %v", err)
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -325,64 +315,140 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
}
}
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
result resutil.LookupResult,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
resp.Rcode = result.Rcode
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.cache.set(domain, question.Qtype, nil)
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
return
}
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
log.Warnf(errResolveFailed, domain, err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
}

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -318,7 +317,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
@@ -466,7 +465,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
@@ -528,7 +527,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
@@ -605,7 +604,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -675,7 +674,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -685,7 +684,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -715,7 +714,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -729,7 +728,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -784,7 +783,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -905,7 +904,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
@@ -938,7 +937,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")

View File

@@ -1251,16 +1251,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
ForwarderPort: forwarderPort,
}
protoZones := protoDNSConfig.GetCustomZones()
// Treat single zone as authoritative for backward compatibility with old servers
// that only send the peer FQDN zone without setting field 4.
singleZoneCompat := len(protoZones) == 1
for _, zone := range protoZones {
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
SkipPTRProcess: zone.GetSkipPTRProcess(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
@@ -1748,26 +1743,22 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
}
e.syncMsgMux.Unlock()
// Skip STUN/TURN probing for JS/WASM as it's not available
relayHealthy := true
if runtime.GOOS != "js" {
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)

View File

@@ -1,4 +1,5 @@
//go:build !windows
// +build !windows
package internal

View File

@@ -669,17 +669,10 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false

View File

@@ -14,7 +14,6 @@ import (
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -159,7 +158,6 @@ type FullStatus struct {
NSGroupStates []NSGroupState
NumOfForwardingRules int
LazyConnectionEnabled bool
Events []*proto.SystemEvent
}
type StatusChangeSubscription struct {
@@ -983,7 +981,6 @@ func (d *Status) GetFullStatus() FullStatus {
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
fullStatus.Events = d.GetEventHistory()
return fullStatus
}
@@ -1184,97 +1181,3 @@ type EventSubscription struct {
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
return s.events
}
// ToProto converts FullStatus to proto.FullStatus.
func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
if err := fs.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fs.SignalState.URL
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
if err := fs.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
for _, peerState := range fs.Peers {
networks := maps.Keys(peerState.GetRoutes())
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: networks,
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fs.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fs.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
pbFullStatus.Events = fs.Events
return &pbFullStatus
}

View File

@@ -17,13 +17,12 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -38,6 +37,11 @@ type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@@ -47,7 +51,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
wgInterface iface.WGIface
wgInterface wgInterface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
@@ -215,14 +219,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 {
return
}
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -245,6 +249,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
@@ -253,15 +263,32 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
if reply == nil {
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return
}
resutil.SetMeta(w, "peer", peerKey)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id
if err := d.writeMsg(w, reply, logger); err != nil {
if err := d.writeMsg(w, reply); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
}
@@ -297,15 +324,11 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
// Clear Zero bit from peer responses to prevent external sources from
// manipulating our internal fallthrough signaling mechanism
r.MsgHdr.Zero = false
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
@@ -327,14 +350,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
@@ -347,11 +370,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
logger.Errorf("failed to update domain prefixes: %v", err)
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
d.replaceIPsInDNSResponse(r, newPrefixes)
}
}
@@ -363,22 +386,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
if len(toAdd) > 0 {
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -395,9 +418,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
@@ -409,7 +432,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.addDNATMappings(dnatMappings, logger)
d.addDNATMappings(dnatMappings)
if !d.route.KeepRoute {
// Remove old prefixes
@@ -425,7 +448,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.removeDNATMappings(toRemove, logger)
d.removeDNATMappings(toRemove)
}
// Update domain prefixes using resolved domain as key - store real IPs
@@ -440,14 +463,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
if len(realPrefixes) == 0 {
return
}
@@ -461,9 +484,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
}
}
}
@@ -479,7 +502,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
if len(mappings) == 0 {
return
}
@@ -491,9 +514,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, log
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -505,12 +528,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
d.removeDNATMappings(prefixes)
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
if _, ok := d.internalDnatFw(); !ok {
return
}
@@ -526,7 +549,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
@@ -537,7 +560,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
@@ -563,44 +586,6 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return
}
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()
nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error
if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}
if err == nil {
return reply
}
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""

View File

@@ -1,4 +1,5 @@
//go:build !windows
// +build !windows
package iface

View File

@@ -4,8 +4,6 @@ import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -20,5 +18,4 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -210,8 +210,7 @@ func (r *SysOps) refreshLocalSubnetsCache() {
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
switch prefix {
case vars.Defaultv4:
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
@@ -234,7 +233,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
}
return nil
case vars.Defaultv6:
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
@@ -256,8 +255,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
switch prefix {
case vars.Defaultv4:
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
@@ -275,7 +273,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
return nberrors.FormatErrorOrNil(result)
case vars.Defaultv6:
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
@@ -285,9 +283,9 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
return nberrors.FormatErrorOrNil(result)
default:
return r.removeFromRouteTable(prefix, nextHop)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {

View File

@@ -76,7 +76,7 @@ type Client struct {
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
preloadedConfig *profilemanager.Config
}
// NewClient instantiate a new Client

View File

@@ -173,9 +173,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level)
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
fwManager.SetLogLevel(level)
log.Infof("Log level set to %s", level.String())

View File

@@ -1,6 +1,8 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
@@ -27,3 +29,8 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
}
}
}
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
events := s.statusRecorder.GetEventHistory()
return &proto.GetEventsResponse{Events: events}, nil
}

View File

@@ -1,4 +1,5 @@
//go:build windows
// +build windows
package server

View File

@@ -13,12 +13,15 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -1064,9 +1067,11 @@ func (s *Server) Status(
if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := fullStatus.ToProto()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1595,6 +1600,94 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -602,13 +602,12 @@ func TestJWTAuthentication(t *testing.T) {
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
switch tc.token {
case "valid":
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
case "invalid":
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),

View File

@@ -325,64 +325,61 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
}
}
// JSON returns the status overview as a JSON string.
func (o *OutputOverview) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
func ParseToJSON(overview OutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
// YAML returns the status overview as a YAML string.
func (o *OutputOverview) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
func ParseToYAML(overview OutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
// GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
if o.ManagementState.Connected {
if overview.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if o.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
}
}
var signalConnString string
if o.SignalState.Connected {
if overview.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if o.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := o.IP
if o.KernelInterface {
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceTypeString = "Kernel"
} else if o.IP == "" {
} else if overview.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range o.Relays.Details {
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
@@ -398,18 +395,18 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
networks := "-"
if len(o.Networks) > 0 {
sort.Strings(o.Networks)
networks = strings.Join(o.Networks, ", ")
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range o.NSServerGroups {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
@@ -433,25 +430,25 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if o.RosenpassEnabled {
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if o.RosenpassPermissive {
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
lazyConnectionEnabledStatus := "false"
if o.LazyConnectionEnabled {
if overview.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true"
}
sshServerStatus := "Disabled"
if o.SSHServerState.Enabled {
sessionCount := len(o.SSHServerState.Sessions)
if overview.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions)
if sessionCount > 0 {
sessionWord := "session"
if sessionCount > 1 {
@@ -463,7 +460,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
if showSSHSessions && sessionCount > 0 {
for _, session := range o.SSHServerState.Sessions {
for _, session := range overview.SSHServerState.Sessions {
var sessionDisplay string
if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
@@ -487,7 +484,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
@@ -515,31 +512,30 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Forwarding rules: %d\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
o.DaemonVersion,
overview.DaemonVersion,
version.NetbirdVersion(),
o.ProfileName,
overview.ProfileName,
managementConnString,
signalConnString,
relaysString,
dnsServersString,
domain.Domain(o.FQDN).SafeString(),
domain.Domain(overview.FQDN).SafeString(),
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
networks,
o.NumberOfForwardingRules,
overview.NumberOfForwardingRules,
peersCountString,
)
return summary
}
// FullDetailSummary returns a full detailed summary with peer details and events.
func (o *OutputOverview) FullDetailSummary() string {
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
parsedEventsString := parseEvents(o.Events)
summary := o.GeneralSummary(true, true, true, true)
func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events)
summary := ParseGeneralSummary(overview, true, true, true, true)
return fmt.Sprintf(
"Peers detail:"+

View File

@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := overview.JSON()
jsonString, _ := ParseToJSON(overview)
//@formatter:off
expectedJSONString := `
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := overview.YAML()
yaml, _ := ParseToYAML(overview)
expectedYAML :=
`peers:
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := overview.FullDetailSummary()
detail := ParseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
`Peers detail:
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := overview.GeneralSummary(false, false, false, false)
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1

View File

@@ -1,3 +1,6 @@
//go:build android
// +build android
package system
import (

View File

@@ -1,4 +1,5 @@
//go:build !ios
// +build !ios
package system

View File

@@ -1,3 +1,6 @@
//go:build ios
// +build ios
package system
import (

View File

@@ -510,7 +510,7 @@ func (s *serviceClient) saveSettings() {
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Configuration updates are disabled by daemon")
dialog.ShowError(fmt.Errorf("configuration updates are disabled by daemon"), s.wSettings)
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
return
}
@@ -540,7 +540,7 @@ func (s *serviceClient) saveSettings() {
func (s *serviceClient) validateSettings() error {
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
return fmt.Errorf("invalid pre-shared key value")
return fmt.Errorf("Invalid Pre-shared Key Value")
}
}
return nil
@@ -549,10 +549,10 @@ func (s *serviceClient) validateSettings() error {
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
if err != nil {
return 0, 0, errors.New("invalid interface port")
return 0, 0, errors.New("Invalid interface port")
}
if port < 1 || port > 65535 {
return 0, 0, errors.New("invalid interface port: out of range 1-65535")
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
}
var mtu int64
@@ -560,7 +560,7 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
if mtuText != "" {
mtu, err = strconv.ParseInt(mtuText, 10, 64)
if err != nil {
return 0, 0, errors.New("invalid MTU value")
return 0, 0, errors.New("Invalid MTU value")
}
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
if sshJWTCacheTTLText != "" {
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
if err != nil {
return nil, errors.New("invalid SSH JWT Cache TTL value")
return nil, errors.New("Invalid SSH JWT Cache TTL value")
}
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)

View File

@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = overview.FullDetailSummary()
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = overview.FullDetailSummary()
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = overview.FullDetailSummary()
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
request := &proto.DebugBundleRequest{

View File

@@ -164,7 +164,7 @@ func sendShowWindowSignal(pid int32) error {
err = windows.SetEvent(eventHandle)
if err != nil {
return fmt.Errorf("error setting event: %w", err)
return fmt.Errorf("Error setting event: %w", err)
}
return nil

View File

@@ -9,31 +9,20 @@ import (
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/proto"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
icmpEchoRequest = 8
icmpCodeEcho = 0
pingBufferSize = 1500
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
)
func main() {
@@ -124,45 +113,18 @@ func createStopMethod(client *netbird.Client) js.Func {
})
}
// validateSSHArgs validates SSH connection arguments
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
if len(args) < 2 {
return "", 0, "", js.ValueOf("error: requires host and port")
}
if args[0].Type() != js.TypeString {
return "", 0, "", js.ValueOf("host parameter must be a string")
}
if args[1].Type() != js.TypeNumber {
return "", 0, "", js.ValueOf("port parameter must be a number")
}
host = args[0].String()
port = args[1].Int()
username = "root"
if len(args) > 2 {
if args[2].Type() == js.TypeString && args[2].String() != "" {
username = args[2].String()
} else if args[2].Type() != js.TypeString {
return "", 0, "", js.ValueOf("username parameter must be a string")
}
}
return host, port, username, js.Undefined()
}
// createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
host, port, username, validationErr := validateSSHArgs(args)
if !validationErr.IsUndefined() {
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
return validationErr
}
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(validationErr)
})
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}
host := args[0].String()
port := args[1].Int()
username := "root"
if len(args) > 2 && args[2].String() != "" {
username = args[2].String()
}
var jwtToken string
@@ -171,9 +133,6 @@ func createSSHMethod(client *netbird.Client) js.Func {
}
return createPromise(func(resolve, reject js.Value) {
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called in ssh.Connect().
sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
@@ -195,110 +154,6 @@ func createSSHMethod(client *netbird.Client) js.Func {
})
}
func performPing(client *netbird.Client, hostname string) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close ping connection: %v", err)
}
}()
icmpData := make([]byte, 8)
icmpData[0] = icmpEchoRequest
icmpData[1] = icmpCodeEcho
if _, err := conn.Write(icmpData); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
return
}
buf := make([]byte, pingBufferSize)
if _, err := conn.Read(buf); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
return
}
latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
}
func performPingTCP(client *netbird.Client, hostname string, port int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
address := fmt.Sprintf("%s:%d", hostname, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return
}
latency := time.Since(start)
if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err)
}
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
}
// createPingMethod creates the ping method
func createPingMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: hostname required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
hostname := args[0].String()
return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname)
resolve.Invoke(js.Undefined())
})
})
}
// createPingTCPMethod creates the pingtcp method
func createPingTCPMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeNumber {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a number"))
})
}
hostname := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port)
resolve.Invoke(js.Undefined())
})
})
}
// createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -307,11 +162,6 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
}
request := args[0]
if request.Type() != js.TypeObject {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("request parameter must be an object"))
})
}
return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request)
@@ -331,255 +181,23 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a string"))
})
}
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}
// getStatusOverview is a helper to get the status overview
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
fullStatus, err := client.Status()
if err != nil {
return nbstatus.OutputOverview{}, err
}
pbFullStatus := fullStatus.ToProto()
statusResp := &proto.StatusResponse{
DaemonVersion: version.NetbirdVersion(),
FullStatus: pbFullStatus,
}
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
}
// createStatusMethod creates the status method that returns JSON
func createStatusMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonStr, err := overview.JSON()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
resolve.Invoke(jsonObj)
})
})
}
// createStatusSummaryMethod creates the statusSummary method
func createStatusSummaryMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
summary := overview.GeneralSummary(false, false, false, false)
js.Global().Get("console").Call("log", summary)
resolve.Invoke(js.Undefined())
})
})
}
// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Info("statusDetail called")
return createPromise(func(resolve, reject js.Value) {
log.Info("statusDetail: getting overview")
overview, err := getStatusOverview(client)
if err != nil {
log.Errorf("statusDetail: getStatusOverview failed: %v", err)
reject.Invoke(js.ValueOf(err.Error()))
return
}
log.Info("statusDetail: generating full detail summary")
detail := overview.FullDetailSummary()
log.Infof("statusDetail: detail length=%d", len(detail))
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}
// createWaitForPeerConnectionMethod creates a method that waits for a peer to be connected
func createWaitForPeerConnectionMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
if len(args) < 1 {
reject.Invoke(js.ValueOf("peer IP address required"))
return
}
peerIP := args[0].String()
timeoutMs := int(defaultPeerConnectionTimeout.Milliseconds())
if len(args) > 1 && !args[1].IsUndefined() && !args[1].IsNull() {
timeoutMs = args[1].Int()
}
timeout := time.Duration(timeoutMs) * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
log.Infof("Waiting for peer %s to be connected (timeout: %v)", peerIP, timeout)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
reject.Invoke(js.ValueOf(fmt.Sprintf("timeout waiting for peer %s to connect", peerIP)))
return
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status: %v", err)
continue
}
for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s)", peerIP, peer.ConnType)
resolve.Invoke(js.ValueOf(map[string]interface{}{
"ip": peer.IP,
"status": peer.Status,
"connType": peer.ConnType,
}))
return
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
})
})
}
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
AllowPartial: true,
}
jsonBytes, err := options.Marshal(syncResp)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
resolve.Invoke(jsonObj)
})
})
}
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
func createSetLogLevelMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: log level required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("log level parameter must be a string"))
})
}
logLevel := args[0].String()
return createPromise(func(resolve, reject js.Value) {
if err := client.SetLogLevel(logLevel); err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
return
}
log.Infof("Log level set to: %s", logLevel)
resolve.Invoke(js.ValueOf(true))
})
})
}
// createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]
// Wrap reject to always log the error
loggingReject := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) > 0 {
log.Errorf("Promise rejected: %v", args[0])
js.Global().Get("console").Call("error", "WASM Promise rejected:", args[0])
}
reject.Invoke(args[0])
return nil
})
go handler(resolve, loggingReject.Value)
go handler(resolve, reject)
return nil
}))
}
// waitForPeerConnected waits for a peer with the given IP to be connected
func waitForPeerConnected(ctx context.Context, client *netbird.Client, peerIP string) error {
log.Infof("Waiting for peer %s to be connected before operation", peerIP)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect", peerIP)
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status while waiting for peer: %v", err)
continue
}
for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s), proceeding with operation", peerIP, peer.ConnType)
return nil
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}
// createDetectSSHServerMethod creates the SSH server detection method
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -602,10 +220,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called. Waiting would cause
// a deadlock since lazy connections only open on traffic.
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
if err != nil {
reject.Invoke(err.Error())
@@ -623,25 +237,17 @@ func createClientObject(client *netbird.Client) js.Value {
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["ping"] = createPingMethod(client)
obj["pingtcp"] = createPingTCPMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
obj["waitForPeerConnection"] = createWaitForPeerConnectionMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client)
return js.ValueOf(obj)
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
func netBirdClientConstructor(this js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]
@@ -657,10 +263,7 @@ func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return
}
// Force trace logging for debugging - must be set BEFORE netbird.New
// as New() will call logrus.SetLevel with options.LogLevel
options.LogLevel = "trace"
if err := util.InitLog("trace", util.LogConsole); err != nil {
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
log.Warnf("Failed to initialize logging: %v", err)
}

View File

@@ -47,8 +47,8 @@ type CustomZone struct {
Records []SimpleRecord
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
SearchDomainDisabled bool
// NonAuthoritative marks user-created zones
NonAuthoritative bool
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
SkipPTRProcess bool
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records

24
go.mod
View File

@@ -1,8 +1,6 @@
module github.com/netbirdio/netbird
go 1.25
toolchain go1.25.5
go 1.24.10
require (
cunicu.li/go-rosenpass v0.4.0
@@ -42,7 +40,7 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.24
github.com/creack/pty v1.1.18
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex/api/v2 v2.4.0
github.com/eko/gocache/lib/v4 v4.2.0
@@ -78,12 +76,12 @@ require (
github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.1.1
github.com/pion/stun/v3 v3.0.0
github.com/pion/transport/v3 v3.0.7
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.55.0
github.com/quic-go/quic-go v0.49.1
github.com/redis/go-redis/v9 v9.7.3
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
@@ -105,7 +103,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.uber.org/mock v0.5.2
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
@@ -122,7 +120,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.7
gorm.io/gorm v1.25.12
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
)
require (
@@ -188,10 +186,12 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
@@ -241,7 +241,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/dtls/v3 v3.0.7 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
@@ -263,7 +263,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
@@ -285,7 +285,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

41
go.sum
View File

@@ -101,6 +101,9 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
@@ -118,8 +121,8 @@ github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmr
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -283,6 +286,7 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -407,8 +411,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -444,8 +448,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -455,14 +459,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
@@ -487,8 +491,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
@@ -574,8 +578,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -618,8 +622,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
@@ -713,6 +717,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -843,5 +848,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=

View File

@@ -1,113 +0,0 @@
package dex
import (
"context"
"log/slog"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/formatter"
)
// LogrusHandler is an slog.Handler that delegates to logrus.
// This allows Dex to use the same log format as the rest of NetBird.
type LogrusHandler struct {
logger *logrus.Logger
attrs []slog.Attr
groups []string
}
// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
func NewLogrusHandler(level slog.Level) *LogrusHandler {
logger := logrus.New()
formatter.SetTextFormatter(logger)
// Map slog level to logrus level
switch level {
case slog.LevelDebug:
logger.SetLevel(logrus.DebugLevel)
case slog.LevelInfo:
logger.SetLevel(logrus.InfoLevel)
case slog.LevelWarn:
logger.SetLevel(logrus.WarnLevel)
case slog.LevelError:
logger.SetLevel(logrus.ErrorLevel)
default:
logger.SetLevel(logrus.WarnLevel)
}
return &LogrusHandler{logger: logger}
}
// Enabled reports whether the handler handles records at the given level.
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
switch level {
case slog.LevelDebug:
return h.logger.IsLevelEnabled(logrus.DebugLevel)
case slog.LevelInfo:
return h.logger.IsLevelEnabled(logrus.InfoLevel)
case slog.LevelWarn:
return h.logger.IsLevelEnabled(logrus.WarnLevel)
case slog.LevelError:
return h.logger.IsLevelEnabled(logrus.ErrorLevel)
default:
return true
}
}
// Handle handles the Record.
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
fields := make(logrus.Fields)
// Add pre-set attributes
for _, attr := range h.attrs {
fields[attr.Key] = attr.Value.Any()
}
// Add record attributes
r.Attrs(func(attr slog.Attr) bool {
fields[attr.Key] = attr.Value.Any()
return true
})
entry := h.logger.WithFields(fields)
switch r.Level {
case slog.LevelDebug:
entry.Debug(r.Message)
case slog.LevelInfo:
entry.Info(r.Message)
case slog.LevelWarn:
entry.Warn(r.Message)
case slog.LevelError:
entry.Error(r.Message)
default:
entry.Info(r.Message)
}
return nil
}
// WithAttrs returns a new Handler with the given attributes added.
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
copy(newAttrs, h.attrs)
copy(newAttrs[len(h.attrs):], attrs)
return &LogrusHandler{
logger: h.logger,
attrs: newAttrs,
groups: h.groups,
}
}
// WithGroup returns a new Handler with the given group appended to the receiver's groups.
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
newGroups := make([]string, len(h.groups)+1)
copy(newGroups, h.groups)
newGroups[len(h.groups)] = name
return &LogrusHandler{
logger: h.logger,
attrs: h.attrs,
groups: newGroups,
}
}

View File

@@ -130,21 +130,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
logLevel := slog.LevelWarn
if yamlConfig.Logger.Level != "" {
switch strings.ToLower(yamlConfig.Logger.Level) {
case "debug":
logLevel = slog.LevelDebug
case "info":
logLevel = slog.LevelInfo
case "warn", "warning":
logLevel = slog.LevelWarn
case "error":
logLevel = slog.LevelError
}
}
logger := slog.New(NewLogrusHandler(logLevel))
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
stor, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
@@ -792,12 +778,11 @@ func (p *Provider) resolveRedirectURI(redirectURI string) string {
// buildOIDCConnectorConfig creates config for OIDC-based connectors
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
oidcConfig := map[string]interface{}{
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"insecureEnableGroups": true,
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
}
switch cfg.Type {
case "zitadel":
@@ -807,9 +792,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["insecureSkipEmailVerified"] = true
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return encodeConnectorConfig(oidcConfig)
}

View File

@@ -270,7 +270,7 @@ AUTH_CLIENT_ID=netbird-dashboard
AUTH_CLIENT_SECRET=
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
USE_AUTH0=false
AUTH_SUPPORTED_SCOPES=openid profile email groups
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
AUTH_REDIRECT_URI=/nb-auth
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
# SSL

View File

@@ -64,7 +64,7 @@ var (
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
}
var tlsEnabled bool
tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
tlsEnabled = true
}
@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
applyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
err := applyEmbeddedIdPConfig(loadedConfig)
if err != nil {
return nil, err
}
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -190,8 +190,10 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
userDeleteFromIDPEnabled = true
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
@@ -203,22 +205,40 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
issuer := cfg.EmbeddedIdP.Issuer
if cfg.HttpConfig != nil {
log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
} else {
// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
// Set AuthIssuer from EmbeddedIdP issuer
if cfg.HttpConfig.AuthIssuer == "" {
cfg.HttpConfig.AuthIssuer = issuer
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
// Set AuthAudience to the dashboard client ID
if cfg.HttpConfig.AuthAudience == "" {
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
}
// Set CLIAuthAudience to the client app client ID
if cfg.HttpConfig.CLIAuthAudience == "" {
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
}
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
if cfg.HttpConfig.AuthUserIDClaim == "" {
cfg.HttpConfig.AuthUserIDClaim = "sub"
}
// Set AuthKeysLocation to the JWKS endpoint
if cfg.HttpConfig.AuthKeysLocation == "" {
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
}
// Set OIDCConfigEndpoint to the discovery endpoint
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
}
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
}
return nil
}
@@ -226,12 +246,7 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" {
return nil
}
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
return nil
}

View File

@@ -36,6 +36,10 @@ import (
"github.com/netbirdio/netbird/util"
)
const (
compactNetworkMapMinVersion = "v0.61.0" // TODO change to real version
)
type Controller struct {
repo Repository
metrics *metrics
@@ -483,6 +487,11 @@ func (c *Controller) getPeerNetworkMapExp(
}
}
peer := account.GetPeer(peerId)
if peer != nil && supportsCompactNetworkMap(peer) {
return account.GetPeerNetworkMapCompactExp(ctx, peerId, customZone, validatedPeers, metrics)
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
}
@@ -622,6 +631,19 @@ func (c *Controller) StartWarmup(ctx context.Context) {
}
func supportsCompactNetworkMap(peer *nbpeer.Peer) bool {
if peer.Meta.WtVersion == "development" || peer.Meta.WtVersion == "dev" {
return true
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
return false
}
return semver.Compare(peerVersion, compactNetworkMapMinVersion) >= 0
}
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {

View File

@@ -68,8 +68,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
if len(audiences) > 0 {
audience = audiences[0] // Use the first client ID as the primary audience
}
// Use localhost keys location for internal validation (management has embedded Dex)
keysLocation = oauthProvider.GetLocalKeysLocation()
keysLocation = oauthProvider.GetKeysLocation()
signingKeyRefreshEnabled = true
issuer = oauthProvider.GetIssuer()
userIDClaim = oauthProvider.GetUserIDClaim()

View File

@@ -374,9 +374,8 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
NonAuthoritative: zone.NonAuthoritative,
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{

View File

@@ -85,7 +85,6 @@ func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
// nolint
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}

View File

@@ -1006,7 +1006,7 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
for user, loggedInOnce := range accountUsers {
if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
return false
}

View File

@@ -753,7 +753,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account.Domain != domain {
if account != nil && account.Domain != domain {
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
}
@@ -768,7 +768,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
t.Fatalf("expected to get an account for a user %s", userId)
}
if account.Domain != domain {
if account != nil && account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
}
@@ -3465,11 +3465,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain})
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, peer2)
require.NoError(t, err)

View File

@@ -26,8 +26,6 @@ type UserDataCache interface {
Get(ctx context.Context, key string) (*idp.UserData, error)
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
Delete(ctx context.Context, key string) error
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
}
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
@@ -53,29 +51,6 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
return u.cache.Delete(ctx, key)
}
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
var users []*idp.UserData
v, err := u.cache.Get(ctx, key, &users)
if err != nil {
return nil, err
}
switch v := v.(type) {
case []*idp.UserData:
return v, nil
case *[]*idp.UserData:
return *v, nil
case []byte:
return unmarshalUserData(v)
}
return nil, fmt.Errorf("unexpected type: %T", v)
}
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
}
// NewUserDataCache creates a new UserDataCacheImpl object.
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
simpleCache := cache.New[any](store)

View File

@@ -893,7 +893,6 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
peer := &peer2.Peer{
ID: strconv.Itoa(i),
AccountID: accountID,
Key: "key" + strconv.Itoa(i),
DNSLabel: "peer" + strconv.Itoa(i),
IP: uint32ToIP(uint32(i)),
}

View File

@@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%v", err) //nolint
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint
}
return postureChecks, nil

View File

@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil && !isTerraformRequest(r) {
if m.rateLimiter != nil {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
@@ -214,11 +214,6 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
func isTerraformRequest(r *http.Request) bool {
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "terraform")
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {

View File

@@ -508,103 +508,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test various Terraform user agent formats
terraformUserAgents := []string{
"Terraform/1.5.0",
"terraform/1.0.0",
"Terraform-Provider/2.0.0",
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
}
for _, userAgent := range terraformUserAgents {
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", userAgent)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
})
}
})
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {

View File

@@ -1,4 +1,5 @@
//go:build benchmark
// +build benchmark
package benchmarks

View File

@@ -1,4 +1,5 @@
//go:build benchmark
// +build benchmark
package benchmarks

View File

@@ -1,4 +1,5 @@
//go:build benchmark
// +build benchmark
package benchmarks

View File

@@ -1,4 +1,5 @@
//go:build integration
// +build integration
package integration

View File

@@ -2,13 +2,7 @@ package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/dexidp/dex/storage"
"github.com/rs/xid"
@@ -23,69 +17,6 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// oidcProviderJSON represents the OpenID Connect discovery document
type oidcProviderJSON struct {
Issuer string `json:"issuer"`
}
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
// and verifying that the returned issuer matches the configured one.
func validateOIDCIssuer(ctx context.Context, issuer string) error {
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
if err != nil {
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
}
var p oidcProviderJSON
if err := json.Unmarshal(body, &p); err != nil {
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
if p.Issuer != issuer {
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
}
return nil
}
// validateIdentityProviderConfig validates the identity provider configuration including
// basic validation and OIDC issuer verification.
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
if err := idpConfig.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
// Validate the issuer by calling the OIDC discovery endpoint
if idpConfig.Issuer != "" {
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
return nil
}
// GetIdentityProviders returns all identity providers for an account
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
@@ -151,8 +82,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
return nil, status.NewPermissionDeniedError()
}
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
return nil, err
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
@@ -188,8 +119,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
return nil, status.NewPermissionDeniedError()
}
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
return nil, err
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)

Some files were not shown because too many files have changed in this diff Show More