Compare commits

...

12 Commits

Author SHA1 Message Date
Zoltán Papp
63598916cb Trigger 2026-01-20 17:47:51 +01:00
Zoltan Papp
d9118eb239 [client] Fix WASM peer connection to lazy peers (#5097)
WASM peers now properly initiate relay connections instead of waiting for offers that lazy peers won't send.
2026-01-13 13:33:15 +01:00
Nima Sadeghifard
94de656fae [misc] Add hiring announcement with link to careers.netbird.io (#5095) 2026-01-12 19:06:28 +01:00
Misha Bragin
37abab8b69 [management] Check config compatibility (#5087)
* Enforce HttpConfig overwrite when embeddedIdp is enabled

* Disable offline_access scope in dashboard by default

* Add group propagation foundation to embedded idp

* Require groups scope in dex config for okt and pocket

* remove offline_access from device default scopes
2026-01-12 17:09:03 +01:00
Viktor Liu
b12c084a50 [client] Fall through dns chain for custom dns zones (#5081) 2026-01-12 13:56:39 +01:00
Viktor Liu
394ad19507 [client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046) 2026-01-12 12:35:38 +01:00
Misha Bragin
614e7d5b90 Validate OIDC issuer when creating or updating (#5074) 2026-01-09 09:45:43 -05:00
Misha Bragin
f7967f9ae3 Feature/resolve local jwks keys (#5073) 2026-01-09 09:41:27 -05:00
Vlad
684fc0d2a2 [management] fix the issue with duplicated peers with the same key (#5053) 2026-01-09 11:49:26 +01:00
Viktor Liu
0ad0c81899 [client] Reorder userspace ACL checks to fail faster for better performance (#4226) 2026-01-09 09:13:04 +01:00
Viktor Liu
e8863fbb55 [client] Add non-root ICMP support to userspace firewall forwarder (#4792) 2026-01-09 02:53:37 +08:00
Zoltan Papp
9c9d8e17d7 Revert "Revert "[relay] Update GO version and QUIC version (#4736)" (#5055)" (#5071)
This reverts commit 24df442198.
2026-01-08 18:58:22 +01:00
127 changed files with 2914 additions and 881 deletions

View File

@@ -1,15 +1,15 @@
FROM golang:1.23-bullseye
FROM golang:1.25-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\
gettext-base=0.21-4 \
iptables=1.8.7-1 \
libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
gettext-base=0.21-12 \
iptables=1.8.9-2 \
libgl1-mesa-dev=22.3.6-1+deb12u1 \
xorg-dev=1:7.7+23 \
libayatana-appindicator3-dev=0.5.92-1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@v0.18.1
&& go install -v golang.org/x/tools/gopls@latest
WORKDIR /app

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.24-alpine \
golang:1.25-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -259,7 +259,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m ./relay/... ./shared/relay/...
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_signal:
name: "Signal / Unit"

View File

@@ -52,7 +52,10 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
args: --timeout=12m --out-format colored-line-number
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
args: --timeout=12m

View File

@@ -63,7 +63,7 @@ jobs:
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"

View File

@@ -14,6 +14,9 @@ jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
env:
GOOS: js
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -24,16 +27,14 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
skip-save-cache: true
cache-invalidation-interval: 0
working-directory: ./client
continue-on-error: true
js_build:

View File

@@ -1,139 +1,124 @@
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
gosec:
includes:
- G101 # Look for hard coded credentials
#- G102 # Bind to all interfaces
- G103 # Audit the use of unsafe block
- G104 # Audit errors not checked
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
#- G107 # Url provided to HTTP request as taint input
- G108 # Profiling endpoint automatically exposed on /debug/pprof
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
- G110 # Potential DoS vulnerability via decompression bomb
- G111 # Potential directory traversal
#- G112 # Potential slowloris attack
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
#- G114 # Use of net/http serve function that has no support for setting timeouts
- G201 # SQL query construction using format string
- G202 # SQL query construction using string concatenation
- G203 # Use of unescaped data in HTML templates
#- G204 # Audit use of command execution
- G301 # Poor file permissions used when creating a directory
- G302 # Poor file permissions used with chmod
- G303 # Creating tempfile using a predictable path
- G304 # File path provided as taint input
- G305 # File traversal when extracting zip/tar archive
- G306 # Poor file permissions used when writing to a new file
- G307 # Poor file permissions used when creating a file with os.Create
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
#- G402 # Look for bad TLS connection settings
- G403 # Ensure minimum RSA key length of 2048 bits
#- G404 # Insecure random number source (rand)
#- G501 # Import blocklist: crypto/md5
- G502 # Import blocklist: crypto/des
- G503 # Import blocklist: crypto/rc4
- G504 # Import blocklist: net/http/cgi
#- G505 # Import blocklist: crypto/sha1
- G601 # Implicit memory aliasing of items from a range statement
- G602 # Slice access out of bounds
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
revive:
rules:
- name: exported
severity: warning
disabled: false
arguments:
- "checkPrivateReceivers"
- "sayRepetitiveInsteadOfStutters"
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
version: "2"
linters:
disable-all: true
default: none
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- dupword # dupword checks for duplicate words in the source code
- durationcheck # durationcheck checks for two durations multiplied together
- forbidigo # forbidigo forbids identifiers
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gosec # inspects source code for security problems
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
- misspell # misspess finds commonly misspelled English words in comments
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
- bodyclose
- dupword
- durationcheck
- errcheck
- forbidigo
- gocritic
- gosec
- govet
- ineffassign
- mirror
- misspell
- nilerr
- nilnil
- predeclared
- revive
- sqlclosecheck
- staticcheck
- unused
- wastedassign
settings:
errcheck:
check-type-assertions: false
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
gosec:
includes:
- G101
- G103
- G104
- G106
- G108
- G109
- G110
- G111
- G201
- G202
- G203
- G301
- G302
- G303
- G304
- G305
- G306
- G307
- G403
- G502
- G503
- G504
- G601
- G602
govet:
enable:
- nilness
enable-all: false
revive:
rules:
- name: exported
arguments:
- checkPrivateReceivers
- sayRepetitiveInsteadOfStutters
severity: warning
disabled: false
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- linters:
- forbidigo
path: management/cmd/root\.go
- linters:
- forbidigo
path: signal/cmd/root\.go
- linters:
- unused
path: sharedsock/filter\.go
- linters:
- unused
path: client/firewall/iptables/rule\.go
- linters:
- gosec
- mirror
path: test\.go
- linters:
- nilnil
path: mock\.go
- linters:
- staticcheck
text: grpc.DialContext is deprecated
- linters:
- staticcheck
text: grpc.WithBlock is deprecated
- linters:
- staticcheck
text: "QF1001"
- linters:
- staticcheck
text: "QF1008"
- linters:
- staticcheck
text: "QF1012"
paths:
- third_party$
- builtin$
- examples$
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
exclude-rules:
# allow fmt
- path: management/cmd/root\.go
linters: forbidigo
- path: signal/cmd/root\.go
linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -38,6 +38,11 @@
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>

View File

@@ -136,6 +136,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
//nolint
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}

View File

@@ -81,6 +81,7 @@ var loginCmd = &cobra.Command{
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
@@ -206,6 +207,7 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -1,5 +1,4 @@
//go:build pprof
// +build pprof
package cmd

View File

@@ -390,6 +390,7 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -124,6 +124,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -89,9 +89,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)

View File

@@ -216,6 +216,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -386,11 +386,8 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is 0.0.0.0
if ip.IsUnspecified() {
matchByIP = false
}
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {

View File

@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex int = -1, -1
var denyRuleIndex, acceptRuleIndex = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)

View File

@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex int = -1, -1
var acceptRuleIndex, denyRuleIndex = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
@@ -208,11 +208,13 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
if lookup.SetName == "accept-http" {
switch lookup.SetName {
case "accept-http":
hasAcceptHTTPSet = true
} else if lookup.SetName == "deny-http" {
case "deny-http":
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
@@ -222,9 +224,10 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
if verdict.Kind == expr.VerdictAccept {
switch verdict.Kind {
case expr.VerdictAccept:
action = "ACCEPT"
} else if verdict.Kind == expr.VerdictDrop {
case expr.VerdictDrop:
action = "DROP"
}
}

View File

@@ -29,7 +29,7 @@ import (
)
const (
layerTypeAll = 0
layerTypeAll = 255
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
@@ -262,10 +262,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering(
@@ -439,19 +436,7 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort
r.dPort = dPort
switch proto {
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
r.protoLayer = protoToLayer(proto, r.ipLayer)
m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet
@@ -496,16 +481,17 @@ func (m *Manager) addRouteFiltering(
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort,
dstPort: dPort,
action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
@@ -795,7 +781,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
var sum = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
@@ -945,7 +931,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked {
_, pnum := getProtocolFromPacket(d)
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -1010,20 +996,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false
}
proto, pnum := getProtocolFromPacket(d)
protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: pnum,
Protocol: proto,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
@@ -1052,16 +1040,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true
}
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return firewall.ProtocolTCP, nftypes.TCP
return nftypes.TCP
case layers.LayerTypeUDP:
return firewall.ProtocolUDP, nftypes.UDP
return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP, nftypes.ICMP
return nftypes.ICMP
default:
return firewall.ProtocolALL, nftypes.ProtocolUnknown
return nftypes.ProtocolUnknown
}
}
@@ -1233,19 +1238,30 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
return nil, false
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
@@ -1264,21 +1280,8 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
break
}
}
if !sourceMatched {
return false
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
return sourceMatched
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched

View File

@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
}
}
}

View File

@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)
})
}
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
}
})
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -2,6 +2,7 @@ package forwarder
import (
"fmt"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -16,7 +17,7 @@ type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
mtu atomic.Uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool {
}
func (e *endpoint) MTU() uint32 {
return e.mtu
return e.mtu.Load()
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
func (e *endpoint) Close() {
// Endpoint cleanup - nothing to do as device is managed externally
}
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
// Link address is not used for this endpoint type
}
func (e *endpoint) SetMTU(mtu uint32) {
e.mtu.Store(mtu)
}
func (e *endpoint) SetOnCloseAction(func()) {
// No action needed on close
}
type epID stack.TransportEndpointID
func (i epID) String() string {

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@@ -35,14 +36,16 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
endpoint.mtu.Store(uint32(mtu))
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("create NIC: %v", err)
@@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
}
receiveWindow := defaultReceiveWindow
@@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
f.checkICMPCapability()
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
@@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
DstPort: dstPort,
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}

View File

@@ -2,8 +2,11 @@ package forwarder
import (
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"runtime"
"time"
"github.com/google/uuid"
@@ -14,30 +17,95 @@ import (
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
// For Echo Requests, send and wait for response
if icmpHdr.Type() == header.ICMPv4Echo {
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
if !f.hasRawICMPAccess {
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
return true
}
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return true
}
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
return nil, fmt.Errorf("create ICMP socket: %w", err)
}
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
}
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
return conn, nil
}
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
}
defer func() {
if err := conn.Close(); err != nil {
@@ -45,38 +113,22 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
response := make([]byte, f.endpoint.mtu)
response := make([]byte, f.endpoint.mtu.Load())
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
@@ -85,31 +137,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
return 0
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
return f.injectICMPReply(id, response[:n])
}
// sendICMPEvent stores flow events for ICMP packets
@@ -152,3 +180,95 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
f.flowLogger.StoreEvent(fields)
}
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
// This is used as a fallback when raw socket access is not available.
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
// Returns the size of the injected packet.
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyICMPHdr := header.ICMPv4(replyICMP)
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
return f.injectICMPReply(id, replyICMP)
}
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
// Returns the total size of the injected packet, or 0 if injection failed.
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
return 0
}
return len(fullPacket)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
@@ -131,10 +132,10 @@ func (f *udpForwarder) cleanup() {
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
return false
}
id := r.ID()
@@ -144,7 +145,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return
return true
}
flowID := uuid.New()
@@ -162,7 +163,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
return false
}
// Create wait queue for blocking syscalls
@@ -173,10 +174,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
return false
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
inConn := gonet.NewUDPConn(&wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
@@ -199,7 +200,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
return true
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
@@ -208,6 +209,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
return true
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
@@ -348,7 +350,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
}
func isTimeout(err error) bool {

View File

@@ -130,6 +130,7 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}

View File

@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
_ = mapManager.localIPs[ip.String()]
}
})
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
_ = mapManager.localIPs[ip.String()]
}
})
}

View File

@@ -168,6 +168,15 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
}
}
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {

View File

@@ -234,9 +234,10 @@ func TestInboundPortDNATNegative(t *testing.T) {
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
switch tc.protocol {
case layers.IPProtocolTCP:
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
case layers.IPProtocolUDP:
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix
dstSet firewall.Set
destinations []netip.Prefix
proto firewall.Protocol
protoLayer gopacket.LayerType
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
strId := string(id)
if id == nil {

View File

@@ -27,8 +27,23 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
}
// ICEBind is a bind implementation with two main features:

View File

@@ -1,6 +1,3 @@
//go:build ios
// +build ios
package device
import (

View File

@@ -507,15 +507,13 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
if p.Base == expr.PayloadBaseNetworkHeader {
switch p.Offset {
case 12:
if p.Len == 4 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
switch p.Len {
case 4, 2:
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
case 16:
if p.Len == 4 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
switch p.Len {
case 4, 2:
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
if zone.NonAuthoritative {
continue
}
for _, record := range zone.Records {

View File

@@ -3,11 +3,15 @@ package dns
import (
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
)
const (
@@ -43,7 +47,23 @@ type HandlerChain struct {
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
requestID string
shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true
return nil
}
w.response = m
return w.ResponseWriter.WriteMsg(m)
}
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
c.logHandlers()
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break
}
}
}
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := strings.ToLower(r.Question[0].Name)
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry)
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
if !c.isHandlerMatch(qname, entry) {
continue
}
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
}
continue
}
c.logResponse(logger, chainWriter, qname, startTime)
return
}
// No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname)
logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":

View File

@@ -1,30 +1,52 @@
package local
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
cancel context.CancelFunc
}
func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
}
@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
@@ -48,35 +81,85 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
logger.Debug("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
result := d.lookupRecords(logger, question)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
logger.Warnf("failed to write the local resolver response: %v", err)
}
}
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
}
}
@@ -89,8 +172,27 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
return exists
}
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock()
records, found := d.records[question]
@@ -98,10 +200,14 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
cnameQuestion := dns.Question{
Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
}
return nil
return lookupResult{rcode: dns.RcodeNameError}
}
recordsCopy := slices.Clone(records)
@@ -119,20 +225,178 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.Unlock()
}
return recordsCopy
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName)) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
}
}
}

View File

@@ -1,8 +1,14 @@
package local
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -12,6 +18,18 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
// mockResolver implements resolver for testing
type mockResolver struct {
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
}
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
if m.lookupFunc != nil {
return m.lookupFunc(ctx, network, host)
}
return nil, nil
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -106,11 +124,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
// Apply first update
resolver.Update(update1)
resolver.Update(zone1)
// Verify first update
resolver.mu.RLock()
@@ -122,7 +140,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2)
resolver.Update(zone2)
// Verify second update
resolver.mu.RLock()
@@ -151,10 +169,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
// Apply update with both records
resolver.Update(update)
resolver.Update(zones)
// Create question that matches both records
question := dns.Question{
@@ -195,10 +213,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
// Apply update with all three records
resolver.Update(update)
resolver.Update(zones)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
@@ -264,7 +282,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
testCases := []struct {
name string
@@ -379,7 +397,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
testCases := []struct {
name string
@@ -476,6 +494,20 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
// with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver()
// Mock external resolver for CNAME target resolution
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "target.example.com." {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
if network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.",
@@ -493,7 +525,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
RData: "target.example.com.",
}
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
testCases := []struct {
name string
@@ -582,3 +614,808 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
})
}
}
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("simple internal CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "target.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "192.168.1.1", a.A.String())
})
t.Run("multi-hop CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3)
})
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok)
})
}
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
t.Run("chain at max depth resolves", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 7 CNAMEs (under max of 8)
for i := 1; i <= 7; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("hop%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("hop%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 8)
})
t.Run("chain exceeding max depth stops", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 10 CNAMEs (exceeds max of 8)
for i := 1; i <= 10; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("deep%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("deep%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
// Should NOT have the final A record (chain too deep)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
},
}})
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
}
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "93.184.216.34", a.A.String())
})
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
aaaa, ok := resp.Answer[1].(*dns.AAAA)
require.True(t, ok)
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
})
t.Run("concurrent external resolution", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
var wg sync.WaitGroup
results := make([]*dns.Msg, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
results[idx] = resp
}(i)
}
wg.Wait()
for i, resp := range results {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
}
})
}
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update sets zones correctly", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}},
{Domain: "test.local."},
})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("other.example.com."))
assert.True(t, resolver.isInManagedZone("sub.test.local."))
assert.False(t, resolver.isInManagedZone("external.com."))
})
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
})
t.Run("Update clears zones", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
resolver.Update(nil)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
}
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
})
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.other.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
})
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
resolver := NewResolver()
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." {
if network == "ip6" {
// No AAAA records
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
if network == "ip4" {
// But A records exist - domain exists
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
// Table-driven test for all external resolution outcomes
externalCases := []struct {
name string
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
expectedRcode int
expectedAnswer int
}{
{
name: "external NXDOMAIN (both A and AAAA not found)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeNameError,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (temporary error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTemporary: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (timeout)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTimeout: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (generic error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, fmt.Errorf("connection refused")
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external success with IPs",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeSuccess,
expectedAnswer: 2, // CNAME + A
},
}
for _, tc := range externalCases {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
if tc.expectedAnswer > 0 {
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "first answer should be CNAME")
}
})
}
}
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
// trigger fallthrough (Zero bit set) when no records match
func TestLocalResolver_Fallthrough(t *testing.T) {
resolver := NewResolver()
record := nbdns.SimpleRecord{
Name: "existing.custom.zone.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}
testCases := []struct {
name string
zones []nbdns.CustomZone
queryName string
expectFallthrough bool
expectRecord bool
}{
{
name: "Authoritative zone returns NXDOMAIN without fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: false,
expectRecord: false,
},
{
name: "Non-authoritative zone triggers fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: true,
expectRecord: false,
},
{
name: "Record found in non-authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
{
name: "Record found in authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resolver.Update(tc.zones)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response")
if tc.expectFallthrough {
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
} else {
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
}
if tc.expectRecord {
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
}
})
}
}
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
t.Run("direct record lookup is authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.True(t, resp.Authoritative)
})
t.Run("external resolution is not authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2)
assert.False(t, resp.Authoritative)
})
}
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Len(t, resp.Answer, 0)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
resolver.Stop()
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})
lookupCtxCanceled := make(chan struct{})
resolver.resolver = &mockResolver{
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
close(lookupStarted)
<-ctx.Done()
close(lookupCtxCanceled)
return nil, ctx.Err()
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
done := make(chan struct{})
go func() {
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
close(done)
}()
<-lookupStarted
resolver.Stop()
select {
case <-lookupCtxCanceled:
case <-time.After(time.Second):
t.Fatal("external lookup context was not canceled")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("ServeDNS did not return after Stop")
}
})
}
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "EXAMPLE.COM.",
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
NonAuthoritative: true,
}})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG)
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
}
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
func BenchmarkFindZone_BestCase(b *testing.B) {
resolver := NewResolver()
// Single zone that matches immediately
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
NonAuthoritative: true,
}})
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough("example.com.")
}
}
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
func BenchmarkFindZone_WorstCase(b *testing.B) {
resolver := NewResolver()
// 100 zones that won't match
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
NonAuthoritative: true,
})
}
resolver.Update(zones)
// Query with many labels that won't match any zone
qname := "a.b.c.d.e.f.g.h.external.com."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
func BenchmarkFindZone_TypicalCase(b *testing.B) {
resolver := NewResolver()
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
resolver.Update([]nbdns.CustomZone{
{Domain: "netbird.cloud.", NonAuthoritative: false},
{Domain: "custom.local.", NonAuthoritative: true},
})
// Query for subdomain of user zone
qname := "myhost.custom.local."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver := NewResolver()
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
})
}
resolver.Update(zones)
// Query that matches zone50
qname := "host.zone50.internal."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.isInManagedZone(qname)
}
}

View File

@@ -0,0 +1,197 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.localResolver.Update(localRecords)
s.localResolver.Update(localZones)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -659,9 +658,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord
var zones []nbdns.CustomZone
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -675,17 +674,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, localRecords, nil
return muxUpdates, zones, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {

View File

@@ -128,7 +128,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalRecords []nbdns.SimpleRecord
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -181,7 +181,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -222,7 +222,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
@@ -230,7 +230,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -252,7 +252,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -274,7 +274,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -300,7 +300,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -316,7 +316,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,8 +510,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -21,6 +20,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -113,10 +113,7 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
logger := log.WithField("request_id", resutil.GetRequestID(w))
u.prepareRequest(r)
@@ -202,11 +199,18 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
@@ -414,16 +418,6 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -18,6 +18,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return nil
}
question := query.Question[0]
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
}
}
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
result resutil.LookupResult,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
resp.Rcode = result.Rcode
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
log.Warnf(errResolveFailed, domain, err)
logger.Warnf(errResolveFailed, domain, result.Err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")

View File

@@ -1251,11 +1251,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
ForwarderPort: forwarderPort,
}
for _, zone := range protoDNSConfig.GetCustomZones() {
protoZones := protoDNSConfig.GetCustomZones()
// Treat single zone as authoritative for backward compatibility with old servers
// that only send the peer FQDN zone without setting field 4.
singleZoneCompat := len(protoZones) == 1
for _, zone := range protoZones {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package internal

View File

@@ -669,10 +669,17 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false

View File

@@ -19,6 +19,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -219,14 +220,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 {
return
}
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -280,15 +281,10 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
if err := d.writeMsg(w, reply, logger); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
}
@@ -324,11 +320,15 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
// Clear Zero bit from peer responses to prevent external sources from
// manipulating our internal fallthrough signaling mechanism
r.MsgHdr.Zero = false
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
@@ -350,14 +350,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
@@ -370,11 +370,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
logger.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes)
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
}
}
@@ -386,22 +386,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -418,9 +418,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
@@ -432,7 +432,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.addDNATMappings(dnatMappings)
d.addDNATMappings(dnatMappings, logger)
if !d.route.KeepRoute {
// Remove old prefixes
@@ -448,7 +448,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.removeDNATMappings(toRemove)
d.removeDNATMappings(toRemove, logger)
}
// Update domain prefixes using resolved domain as key - store real IPs
@@ -463,14 +463,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
if len(realPrefixes) == 0 {
return
}
@@ -484,9 +484,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -502,7 +502,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
if len(mappings) == 0 {
return
}
@@ -514,9 +514,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -528,12 +528,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes)
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
if _, ok := d.internalDnatFw(); !ok {
return
}
@@ -549,7 +549,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
@@ -560,7 +560,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package iface

View File

@@ -210,7 +210,8 @@ func (r *SysOps) refreshLocalSubnetsCache() {
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
switch prefix {
case vars.Defaultv4:
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
@@ -233,7 +234,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
}
return nil
} else if prefix == vars.Defaultv6 {
case vars.Defaultv6:
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
@@ -255,7 +256,8 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
switch prefix {
case vars.Defaultv4:
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
@@ -273,7 +275,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
return nberrors.FormatErrorOrNil(result)
} else if prefix == vars.Defaultv6 {
case vars.Defaultv6:
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
@@ -283,9 +285,9 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
return nberrors.FormatErrorOrNil(result)
default:
return r.removeFromRouteTable(prefix, nextHop)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {

View File

@@ -76,7 +76,7 @@ type Client struct {
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
preloadedConfig *profilemanager.Config
}
// NewClient instantiate a new Client

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package server

View File

@@ -602,12 +602,13 @@ func TestJWTAuthentication(t *testing.T) {
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
switch tc.token {
case "valid":
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
case "invalid":
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),

View File

@@ -1,6 +1,3 @@
//go:build android
// +build android
package system
import (

View File

@@ -1,5 +1,4 @@
//go:build !ios
// +build !ios
package system

View File

@@ -1,6 +1,3 @@
//go:build ios
// +build ios
package system
import (

View File

@@ -510,7 +510,7 @@ func (s *serviceClient) saveSettings() {
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Configuration updates are disabled by daemon")
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
dialog.ShowError(fmt.Errorf("configuration updates are disabled by daemon"), s.wSettings)
return
}
@@ -540,7 +540,7 @@ func (s *serviceClient) saveSettings() {
func (s *serviceClient) validateSettings() error {
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
return fmt.Errorf("Invalid Pre-shared Key Value")
return fmt.Errorf("invalid pre-shared key value")
}
}
return nil
@@ -549,10 +549,10 @@ func (s *serviceClient) validateSettings() error {
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
if err != nil {
return 0, 0, errors.New("Invalid interface port")
return 0, 0, errors.New("invalid interface port")
}
if port < 1 || port > 65535 {
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
return 0, 0, errors.New("invalid interface port: out of range 1-65535")
}
var mtu int64
@@ -560,7 +560,7 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
if mtuText != "" {
mtu, err = strconv.ParseInt(mtuText, 10, 64)
if err != nil {
return 0, 0, errors.New("Invalid MTU value")
return 0, 0, errors.New("invalid MTU value")
}
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
if sshJWTCacheTTLText != "" {
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
if err != nil {
return nil, errors.New("Invalid SSH JWT Cache TTL value")
return nil, errors.New("invalid SSH JWT Cache TTL value")
}
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)

View File

@@ -164,7 +164,7 @@ func sendShowWindowSignal(pid int32) error {
err = windows.SetEvent(eventHandle)
if err != nil {
return fmt.Errorf("Error setting event: %w", err)
return fmt.Errorf("error setting event: %w", err)
}
return nil

View File

@@ -47,8 +47,8 @@ type CustomZone struct {
Records []SimpleRecord
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
SearchDomainDisabled bool
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
SkipPTRProcess bool
// NonAuthoritative marks user-created zones
NonAuthoritative bool
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records

16
go.mod
View File

@@ -1,6 +1,8 @@
module github.com/netbirdio/netbird
go 1.24.10
go 1.25
toolchain go1.25.5
require (
cunicu.li/go-rosenpass v0.4.0
@@ -40,7 +42,7 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/creack/pty v1.1.24
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex/api/v2 v2.4.0
github.com/eko/gocache/lib/v4 v4.2.0
@@ -81,7 +83,7 @@ require (
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.49.1
github.com/quic-go/quic-go v0.55.0
github.com/redis/go-redis/v9 v9.7.3
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
@@ -103,7 +105,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.uber.org/mock v0.5.0
go.uber.org/mock v0.5.2
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
@@ -120,7 +122,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.7
gorm.io/gorm v1.25.12
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c
)
require (
@@ -186,12 +188,10 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
@@ -285,7 +285,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

25
go.sum
View File

@@ -101,9 +101,6 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
@@ -121,8 +118,8 @@ github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmr
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -286,7 +283,6 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -411,8 +407,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -491,8 +487,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
@@ -622,8 +618,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
@@ -717,7 +713,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -848,5 +843,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=

113
idp/dex/logrus_handler.go Normal file
View File

@@ -0,0 +1,113 @@
package dex
import (
"context"
"log/slog"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/formatter"
)
// LogrusHandler is an slog.Handler that delegates to logrus.
// This allows Dex to use the same log format as the rest of NetBird.
type LogrusHandler struct {
logger *logrus.Logger
attrs []slog.Attr
groups []string
}
// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
func NewLogrusHandler(level slog.Level) *LogrusHandler {
logger := logrus.New()
formatter.SetTextFormatter(logger)
// Map slog level to logrus level
switch level {
case slog.LevelDebug:
logger.SetLevel(logrus.DebugLevel)
case slog.LevelInfo:
logger.SetLevel(logrus.InfoLevel)
case slog.LevelWarn:
logger.SetLevel(logrus.WarnLevel)
case slog.LevelError:
logger.SetLevel(logrus.ErrorLevel)
default:
logger.SetLevel(logrus.WarnLevel)
}
return &LogrusHandler{logger: logger}
}
// Enabled reports whether the handler handles records at the given level.
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
switch level {
case slog.LevelDebug:
return h.logger.IsLevelEnabled(logrus.DebugLevel)
case slog.LevelInfo:
return h.logger.IsLevelEnabled(logrus.InfoLevel)
case slog.LevelWarn:
return h.logger.IsLevelEnabled(logrus.WarnLevel)
case slog.LevelError:
return h.logger.IsLevelEnabled(logrus.ErrorLevel)
default:
return true
}
}
// Handle handles the Record.
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
fields := make(logrus.Fields)
// Add pre-set attributes
for _, attr := range h.attrs {
fields[attr.Key] = attr.Value.Any()
}
// Add record attributes
r.Attrs(func(attr slog.Attr) bool {
fields[attr.Key] = attr.Value.Any()
return true
})
entry := h.logger.WithFields(fields)
switch r.Level {
case slog.LevelDebug:
entry.Debug(r.Message)
case slog.LevelInfo:
entry.Info(r.Message)
case slog.LevelWarn:
entry.Warn(r.Message)
case slog.LevelError:
entry.Error(r.Message)
default:
entry.Info(r.Message)
}
return nil
}
// WithAttrs returns a new Handler with the given attributes added.
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
copy(newAttrs, h.attrs)
copy(newAttrs[len(h.attrs):], attrs)
return &LogrusHandler{
logger: h.logger,
attrs: newAttrs,
groups: h.groups,
}
}
// WithGroup returns a new Handler with the given group appended to the receiver's groups.
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
newGroups := make([]string, len(h.groups)+1)
copy(newGroups, h.groups)
newGroups[len(h.groups)] = name
return &LogrusHandler{
logger: h.logger,
attrs: h.attrs,
groups: newGroups,
}
}

View File

@@ -130,7 +130,21 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
logLevel := slog.LevelWarn
if yamlConfig.Logger.Level != "" {
switch strings.ToLower(yamlConfig.Logger.Level) {
case "debug":
logLevel = slog.LevelDebug
case "info":
logLevel = slog.LevelInfo
case "warn", "warning":
logLevel = slog.LevelWarn
case "error":
logLevel = slog.LevelError
}
}
logger := slog.New(NewLogrusHandler(logLevel))
stor, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
@@ -778,11 +792,12 @@ func (p *Provider) resolveRedirectURI(redirectURI string) string {
// buildOIDCConnectorConfig creates config for OIDC-based connectors
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
oidcConfig := map[string]interface{}{
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"issuer": cfg.Issuer,
"clientID": cfg.ClientID,
"clientSecret": cfg.ClientSecret,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"insecureEnableGroups": true,
}
switch cfg.Type {
case "zitadel":
@@ -792,6 +807,9 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["insecureSkipEmailVerified"] = true
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
}
return encodeConnectorConfig(oidcConfig)
}

View File

@@ -270,7 +270,7 @@ AUTH_CLIENT_ID=netbird-dashboard
AUTH_CLIENT_SECRET=
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
USE_AUTH0=false
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
AUTH_SUPPORTED_SCOPES=openid profile email groups
AUTH_REDIRECT_URI=/nb-auth
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
# SSL

View File

@@ -64,7 +64,7 @@ var (
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
}
tlsEnabled := false
var tlsEnabled bool
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
tlsEnabled = true
}
@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
applyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(loadedConfig)
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -190,10 +190,8 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
userDeleteFromIDPEnabled = true
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
@@ -205,40 +203,22 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
issuer := cfg.EmbeddedIdP.Issuer
// Set AuthIssuer from EmbeddedIdP issuer
if cfg.HttpConfig.AuthIssuer == "" {
cfg.HttpConfig.AuthIssuer = issuer
if cfg.HttpConfig != nil {
log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
} else {
// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set AuthAudience to the dashboard client ID
if cfg.HttpConfig.AuthAudience == "" {
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
}
// Set CLIAuthAudience to the client app client ID
if cfg.HttpConfig.CLIAuthAudience == "" {
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
}
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
if cfg.HttpConfig.AuthUserIDClaim == "" {
cfg.HttpConfig.AuthUserIDClaim = "sub"
}
// Set AuthKeysLocation to the JWKS endpoint
if cfg.HttpConfig.AuthKeysLocation == "" {
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
}
// Set OIDCConfigEndpoint to the discovery endpoint
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
}
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
@@ -246,7 +226,12 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
if oidcEndpoint == "" {
return nil
}
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
return nil
}

View File

@@ -68,7 +68,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
if len(audiences) > 0 {
audience = audiences[0] // Use the first client ID as the primary audience
}
keysLocation = oauthProvider.GetKeysLocation()
// Use localhost keys location for internal validation (management has embedded Dex)
keysLocation = oauthProvider.GetLocalKeysLocation()
signingKeyRefreshEnabled = true
issuer = oauthProvider.GetIssuer()
userIDClaim = oauthProvider.GetUserIDClaim()

View File

@@ -374,8 +374,9 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{

View File

@@ -85,6 +85,7 @@ func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
// nolint
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}

View File

@@ -1006,7 +1006,7 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
for user, loggedInOnce := range accountUsers {
if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
return false
}

View File

@@ -753,7 +753,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account != nil && account.Domain != domain {
if account.Domain != domain {
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
}
@@ -768,7 +768,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
t.Fatalf("expected to get an account for a user %s", userId)
}
if account != nil && account.Domain != domain {
if account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
}
@@ -3465,11 +3465,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain})
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, peer2)
require.NoError(t, err)

View File

@@ -893,6 +893,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
peer := &peer2.Peer{
ID: strconv.Itoa(i),
AccountID: accountID,
Key: "key" + strconv.Itoa(i),
DNSLabel: "peer" + strconv.Itoa(i),
IP: uint32ToIP(uint32(i)),
}

View File

@@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint
return nil, status.Errorf(status.InvalidArgument, "%v", err) //nolint
}
return postureChecks, nil

View File

@@ -1,5 +1,4 @@
//go:build benchmark
// +build benchmark
package benchmarks

View File

@@ -1,5 +1,4 @@
//go:build benchmark
// +build benchmark
package benchmarks

View File

@@ -1,5 +1,4 @@
//go:build benchmark
// +build benchmark
package benchmarks

View File

@@ -1,5 +1,4 @@
//go:build integration
// +build integration
package integration

View File

@@ -2,7 +2,13 @@ package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/dexidp/dex/storage"
"github.com/rs/xid"
@@ -17,6 +23,69 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// oidcProviderJSON represents the OpenID Connect discovery document
type oidcProviderJSON struct {
Issuer string `json:"issuer"`
}
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
// and verifying that the returned issuer matches the configured one.
func validateOIDCIssuer(ctx context.Context, issuer string) error {
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
if err != nil {
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
}
var p oidcProviderJSON
if err := json.Unmarshal(body, &p); err != nil {
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
if p.Issuer != issuer {
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
}
return nil
}
// validateIdentityProviderConfig validates the identity provider configuration including
// basic validation and OIDC issuer verification.
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
if err := idpConfig.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
// Validate the issuer by calling the OIDC discovery endpoint
if idpConfig.Issuer != "" {
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
return nil
}
// GetIdentityProviders returns all identity providers for an account
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
@@ -82,8 +151,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
return nil, status.NewPermissionDeniedError()
}
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
return nil, err
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
@@ -119,8 +188,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
return nil, status.NewPermissionDeniedError()
}
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
return nil, err
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)

View File

@@ -2,6 +2,10 @@ package server
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
@@ -200,3 +204,109 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "name is required")
}
func TestValidateOIDCIssuer(t *testing.T) {
tests := []struct {
name string
setupServer func() *httptest.Server
expectedErr error
expectedErrMsg string
}{
{
name: "issuer mismatch",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := oidcProviderJSON{Issuer: "https://different-issuer.com"}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
},
expectedErr: types.ErrIdentityProviderIssuerMismatch,
expectedErrMsg: "does not match",
},
{
name: "server returns non-200 status",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte("not found"))
}))
},
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
expectedErrMsg: "404",
},
{
name: "server returns invalid JSON",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("invalid json"))
}))
},
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
expectedErrMsg: "failed to decode",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := tt.setupServer()
defer server.Close()
err := validateOIDCIssuer(context.Background(), server.URL)
require.Error(t, err)
assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err)
if tt.expectedErrMsg != "" {
assert.Contains(t, err.Error(), tt.expectedErrMsg)
}
})
}
}
func TestValidateOIDCIssuer_Success(t *testing.T) {
// Create a server that returns its own URL as the issuer
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
resp := oidcProviderJSON{Issuer: server.URL}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
err := validateOIDCIssuer(context.Background(), server.URL)
require.NoError(t, err)
}
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
// Use a URL that will definitely fail to connect
err := validateOIDCIssuer(context.Background(), "http://localhost:59999")
require.Error(t, err)
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable))
}
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
// Test that trailing slashes are handled correctly
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
// Return issuer without trailing slash
resp := oidcProviderJSON{Issuer: server.URL}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
// Pass issuer with trailing slash
err := validateOIDCIssuer(context.Background(), server.URL+"/")
// This should fail because the issuer returned doesn't have trailing slash
require.Error(t, err)
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/dexidp/dex/storage"
"github.com/google/uuid"
@@ -19,7 +20,7 @@ const (
staticClientCLI = "netbird-cli"
defaultCLIRedirectURL1 = "http://localhost:53000/"
defaultCLIRedirectURL2 = "http://localhost:54000/"
defaultScopes = "openid profile email offline_access"
defaultScopes = "openid profile email"
defaultUserIDClaim = "sub"
)
@@ -27,8 +28,11 @@ const (
type EmbeddedIdPConfig struct {
// Enabled indicates whether the embedded IDP is enabled
Enabled bool
// Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2")
// Issuer is the OIDC issuer URL (e.g., "https://management.netbird.io/oauth2")
Issuer string
// LocalAddress is the management server's local listen address (e.g., ":8080" or "localhost:8080")
// Used for internal JWT validation to avoid external network calls
LocalAddress string
// Storage configuration for the IdP database
Storage EmbeddedStorageConfig
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
@@ -146,7 +150,12 @@ var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil)
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
type OAuthConfigProvider interface {
GetIssuer() string
// GetKeysLocation returns the public JWKS endpoint URL (uses external issuer URL)
GetKeysLocation() string
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal use.
// Management server has embedded Dex and can validate tokens via localhost,
// avoiding external network calls and DNS resolution issues during startup.
GetLocalKeysLocation() string
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
@@ -500,6 +509,22 @@ func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()
}
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal token validation.
// Uses the LocalAddress from config (management server's listen address) since embedded Dex
// is served by the management HTTP server, not a standalone Dex server.
func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
addr := m.config.LocalAddress
if addr == "" {
return ""
}
// Construct localhost URL from listen address
// addr is in format ":port" or "host:port" or "localhost:port"
if strings.HasPrefix(addr, ":") {
return fmt.Sprintf("http://localhost%s/oauth2/keys", addr)
}
return fmt.Sprintf("http://%s/oauth2/keys", addr)
}
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
func (m *EmbeddedIdPManager) GetClientIDs() []string {
return []string{staticClientDashboard, staticClientCLI}

View File

@@ -247,3 +247,61 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
t.Logf(" Raw UUID: %s", rawUserID)
t.Logf(" Connector: %s", connectorID)
}
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
tests := []struct {
name string
localAddress string
expected string
}{
{
name: "localhost with port",
localAddress: "localhost:8080",
expected: "http://localhost:8080/oauth2/keys",
},
{
name: "localhost with https port",
localAddress: "localhost:443",
expected: "http://localhost:443/oauth2/keys",
},
{
name: "port only format",
localAddress: ":8080",
expected: "http://localhost:8080/oauth2/keys",
},
{
name: "empty address",
localAddress: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAddress: tt.localAddress,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex-"+tt.name+".db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
result := manager.GetLocalKeysLocation()
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -121,7 +121,7 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
if !slices.Contains(MethodsWithBody, method) && body != "" {
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
return nil, fmt.Errorf("body provided to unsupported method: %s", method)
}
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
@@ -301,7 +301,7 @@ func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountCreateUser()
}
var pending bool = true
pending := true
ret := &UserData{
Email: email,
Name: name,

View File

@@ -357,7 +357,7 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
return nil, err
}
var pending bool = true
pending := true
ret := &UserData{
Email: email,
Name: name,

View File

@@ -393,7 +393,7 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
return fmt.Errorf("failed to parse model schema: %w", err)
}
tableName := stmt.Schema.Table
dialect := db.Dialector.Name()
dialect := db.Name()
if db.Migrator().HasIndex(&model, indexName) {
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
@@ -404,10 +404,11 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
if dialect == "mysql" {
var withLength []string
for _, col := range columns {
if col == "ip" || col == "dns_label" {
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
quotedCol := fmt.Sprintf("`%s`", col)
if col == "ip" || col == "dns_label" || col == "key" {
withLength = append(withLength, fmt.Sprintf("%s(64)", quotedCol))
} else {
withLength = append(withLength, col)
withLength = append(withLength, quotedCol)
}
}
columnClause = strings.Join(withLength, ", ")
@@ -487,3 +488,54 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
return nil
}
func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
if !db.Migrator().HasTable("peers") {
log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")
return nil
}
keyColumn := GetColumnName(db, "key")
var duplicates []struct {
Key string
Count int64
}
if err := db.Table("peers").
Select(keyColumn + ", COUNT(*) as count").
Group(keyColumn).
Having("COUNT(*) > 1").
Find(&duplicates).Error; err != nil {
return fmt.Errorf("find duplicate keys: %w", err)
}
if len(duplicates) == 0 {
return nil
}
log.WithContext(ctx).Warnf("Found %d duplicate peer keys, cleaning up", len(duplicates))
for _, dup := range duplicates {
var peerIDs []string
if err := db.Table("peers").
Select("id").
Where(keyColumn+" = ?", dup.Key).
Order("peer_status_last_seen DESC").
Pluck("id", &peerIDs).Error; err != nil {
return fmt.Errorf("get peers for key: %w", err)
}
if len(peerIDs) <= 1 {
continue
}
idsToDelete := peerIDs[1:]
if err := db.Table("peers").Where("id IN ?", idsToDelete).Delete(nil).Error; err != nil {
return fmt.Errorf("delete duplicate peers: %w", err)
}
}
return nil
}

View File

@@ -340,3 +340,104 @@ func TestCreateIndexIfExists(t *testing.T) {
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
assert.True(t, exist, "Should have the index")
}
type testPeer struct {
ID string `gorm:"primaryKey"`
Key string `gorm:"index"`
PeerStatusLastSeen time.Time
PeerStatusConnected bool
}
func (testPeer) TableName() string {
return "peers"
}
func setupPeerTestDB(t *testing.T) *gorm.DB {
t.Helper()
db := setupDatabase(t)
_ = db.Migrator().DropTable(&testPeer{})
err := db.AutoMigrate(&testPeer{})
require.NoError(t, err, "Failed to auto-migrate tables")
return db
}
func TestRemoveDuplicatePeerKeys_NoDuplicates(t *testing.T) {
db := setupPeerTestDB(t)
now := time.Now()
peers := []testPeer{
{ID: "peer1", Key: "key1", PeerStatusLastSeen: now},
{ID: "peer2", Key: "key2", PeerStatusLastSeen: now},
{ID: "peer3", Key: "key3", PeerStatusLastSeen: now},
}
for _, p := range peers {
err := db.Create(&p).Error
require.NoError(t, err)
}
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
require.NoError(t, err)
var count int64
db.Model(&testPeer{}).Count(&count)
assert.Equal(t, int64(len(peers)), count, "All peers should remain when no duplicates")
}
func TestRemoveDuplicatePeerKeys_WithDuplicates(t *testing.T) {
db := setupPeerTestDB(t)
now := time.Now()
peers := []testPeer{
{ID: "peer1", Key: "key1", PeerStatusLastSeen: now.Add(-2 * time.Hour)},
{ID: "peer2", Key: "key1", PeerStatusLastSeen: now.Add(-1 * time.Hour)},
{ID: "peer3", Key: "key1", PeerStatusLastSeen: now},
{ID: "peer4", Key: "key2", PeerStatusLastSeen: now},
{ID: "peer5", Key: "key3", PeerStatusLastSeen: now.Add(-1 * time.Hour)},
{ID: "peer6", Key: "key3", PeerStatusLastSeen: now},
}
for _, p := range peers {
err := db.Create(&p).Error
require.NoError(t, err)
}
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
require.NoError(t, err)
var count int64
db.Model(&testPeer{}).Count(&count)
assert.Equal(t, int64(3), count, "Should have 3 peers after removing duplicates")
var remainingPeers []testPeer
err = db.Find(&remainingPeers).Error
require.NoError(t, err)
remainingIDs := make(map[string]bool)
for _, p := range remainingPeers {
remainingIDs[p.ID] = true
}
assert.True(t, remainingIDs["peer3"], "peer3 should remain (most recent for key1)")
assert.True(t, remainingIDs["peer4"], "peer4 should remain (only peer for key2)")
assert.True(t, remainingIDs["peer6"], "peer6 should remain (most recent for key3)")
assert.False(t, remainingIDs["peer1"], "peer1 should be deleted (older duplicate)")
assert.False(t, remainingIDs["peer2"], "peer2 should be deleted (older duplicate)")
assert.False(t, remainingIDs["peer5"], "peer5 should be deleted (older duplicate)")
}
func TestRemoveDuplicatePeerKeys_EmptyTable(t *testing.T) {
db := setupPeerTestDB(t)
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
require.NoError(t, err, "Should not fail on empty table")
}
func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) {
db := setupDatabase(t)
_ = db.Migrator().DropTable(&testPeer{})
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
require.NoError(t, err, "Should not fail when table does not exist")
}

View File

@@ -20,7 +20,7 @@ import (
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
var invalidDomainName = errors.New("invalid domain name")
var errInvalidDomainName = errors.New("invalid domain name")
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
@@ -314,7 +314,7 @@ func validateDomain(domain string) error {
_, valid := dns.IsDomainName(domain)
if !valid {
return invalidDomainName
return errInvalidDomainName
}
return nil

View File

@@ -19,7 +19,7 @@ type Peer struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// WireGuard public key
Key string `gorm:"index"`
Key string // uniqueness index (check migrations)
// IP address of the Peer
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
// Meta is a Peer system meta data

View File

@@ -2129,12 +2129,14 @@ func Test_DeletePeer(t *testing.T) {
"peer1": {
ID: "peer1",
AccountID: accountID,
Key: "key1",
IP: net.IP{1, 1, 1, 1},
DNSLabel: "peer1.test",
},
"peer2": {
ID: "peer2",
AccountID: accountID,
Key: "key2",
IP: net.IP{2, 2, 2, 2},
DNSLabel: "peer2.test",
},

View File

@@ -158,7 +158,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint
return status.Errorf(status.InvalidArgument, "%v", err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.

View File

@@ -997,9 +997,10 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {
// Find posture checks by ID
var pc1, pc2 *posture.Checks
for _, pc := range retrievedAccount.PostureChecks {
if pc.ID == postureCheckID1 {
switch pc.ID {
case postureCheckID1:
pc1 = pc
} else if pc.ID == postureCheckID2 {
case postureCheckID2:
pc2 = pc
}
}

View File

@@ -30,7 +30,6 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util/crypt"
)
@@ -110,12 +109,12 @@ func runLargeTest(t *testing.T, store Store) {
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
route := &nbroute.Route{
ID: nbroute.ID(fmt.Sprintf("network-id-%d", n)),
Description: "base route",
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
NetID: nbroute.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
NetworkType: nbroute.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
@@ -689,7 +688,7 @@ func TestMigrate(t *testing.T) {
require.NoError(t, err, "Failed to insert Gob data")
type route struct {
route2.Route
nbroute.Route
Network netip.Prefix `gorm:"serializer:gob"`
PeerGroups []string `gorm:"serializer:gob"`
}
@@ -698,7 +697,7 @@ func TestMigrate(t *testing.T) {
rt := &route{
Network: prefix,
PeerGroups: []string{"group1", "group2"},
Route: route2.Route{ID: "route1"},
Route: nbroute.Route{ID: "route1"},
}
err = store.(*SqlStore).db.Save(rt).Error
@@ -714,7 +713,7 @@ func TestMigrate(t *testing.T) {
require.NoError(t, err, "Failed to delete Gob data")
prefix = netip.MustParsePrefix("12.0.0.0/24")
nRT := &route2.Route{
nRT := &nbroute.Route{
Network: prefix,
ID: "route2",
Peer: "peer-id",
@@ -969,6 +968,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
Key: "key1",
DNSLabel: "peer1",
IP: net.IP{1, 1, 1, 1},
}
@@ -983,6 +983,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
Key: "key2",
DNSLabel: "peer1-1",
IP: net.IP{2, 2, 2, 2},
}
@@ -1010,6 +1011,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
Key: "key1",
DNSLabel: "peer1",
IP: net.IP{1, 1, 1, 1},
}
@@ -1023,6 +1025,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
Key: "key2",
DNSLabel: "peer1-1",
IP: net.IP{2, 2, 2, 2},
}
@@ -1049,6 +1052,7 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
Key: "key1",
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
@@ -1057,6 +1061,7 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) {
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
Key: "key2",
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
@@ -1074,6 +1079,7 @@ func Test_AddPeerWithSameIP(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
Key: "key1",
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
@@ -1082,6 +1088,7 @@ func Test_AddPeerWithSameIP(t *testing.T) {
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
Key: "key2",
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer2)
@@ -3544,13 +3551,13 @@ func TestSqlStore_SaveRoute(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
route := &route2.Route{
route := &nbroute.Route{
ID: "route-id",
AccountID: accountID,
Network: netip.MustParsePrefix("10.10.0.0/16"),
NetID: "netID",
PeerGroups: []string{"routeA"},
NetworkType: route2.IPv4Network,
NetworkType: nbroute.IPv4Network,
Masquerade: true,
Metric: 9999,
Enabled: true,
@@ -3697,6 +3704,7 @@ func BenchmarkGetAccountPeers(b *testing.B) {
peer := &nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
AccountID: accountID,
Key: fmt.Sprintf("key-%d", i),
DNSLabel: fmt.Sprintf("peer%d.example.com", i),
IP: intToIPv4(uint32(i)),
}

View File

@@ -350,8 +350,13 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.MigrateNewField[types.User](ctx, db, "email", "")
},
func(db *gorm.DB) error {
return migration.RemoveDuplicatePeerKeys(ctx, db)
},
}
} // migratePostAuto migrates the SQLite database to the latest schema
}
// migratePostAuto migrates the SQLite database to the latest schema
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
migrations := getMigrationsPostAuto(ctx)
@@ -381,6 +386,12 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
}
})
},
func(db *gorm.DB) error {
return migration.DropIndex[nbpeer.Peer](ctx, db, "idx_peers_key")
},
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
},
}
}

View File

@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);

View File

@@ -18,7 +18,7 @@ CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,
CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_peers_account_id_ip` ON `peers`(`account_id`,`ip`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
@@ -54,4 +54,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Do=','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);

View File

@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);

View File

@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
@@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,'');

View File

@@ -14,7 +14,7 @@ CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account
CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
CREATE INDEX `idx_peers_key` ON `peers`(`key`);
CREATE UNIQUE INDEX `idx_peers_key_unique` ON `peers`(`key`);
CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);

View File

@@ -1,5 +1,4 @@
//go:build !ios
// +build !ios
package testutil

View File

@@ -1,5 +1,4 @@
//go:build ios
// +build ios
package testutil

View File

@@ -7,12 +7,14 @@ import (
// Identity provider validation errors
var (
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
ErrIdentityProviderIssuerUnreachable = errors.New("identity provider issuer is unreachable")
ErrIdentityProviderIssuerMismatch = errors.New("identity provider issuer does not match the issuer returned by the provider")
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
)
// IdentityProviderType is the type of identity provider

Some files were not shown because too many files have changed in this diff Show More