mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 15:16:48 -04:00
Compare commits
14 Commits
wasm-debug
...
nmap/compa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca432ff681 | ||
|
|
7b5d7aeb2e | ||
|
|
3bdce8d0b6 | ||
|
|
d534ce9dfc | ||
|
|
bbc2b42807 | ||
|
|
db9cc52c96 | ||
|
|
3209b241d9 | ||
|
|
7566afd7d0 | ||
|
|
e93d4132d3 | ||
|
|
21e5e6ddff | ||
|
|
10fb18736b | ||
|
|
942abeca0c | ||
|
|
e184a43e8a | ||
|
|
f33f84299f |
@@ -1,15 +1,15 @@
|
||||
FROM golang:1.25-bookworm
|
||||
FROM golang:1.23-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
gettext-base=0.21-12 \
|
||||
iptables=1.8.9-2 \
|
||||
libgl1-mesa-dev=22.3.6-1+deb12u1 \
|
||||
xorg-dev=1:7.7+23 \
|
||||
libayatana-appindicator3-dev=0.5.92-1 \
|
||||
gettext-base=0.21-4 \
|
||||
iptables=1.8.7-1 \
|
||||
libgl1-mesa-dev=20.3.5-1 \
|
||||
xorg-dev=1:7.7+22 \
|
||||
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& go install -v golang.org/x/tools/gopls@latest
|
||||
&& go install -v golang.org/x/tools/gopls@v0.18.1
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
4
.github/workflows/golang-test-linux.yml
vendored
4
.github/workflows/golang-test-linux.yml
vendored
@@ -200,7 +200,7 @@ jobs:
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.25-alpine \
|
||||
golang:1.24-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
@@ -259,7 +259,7 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
-timeout 10m ./relay/... ./shared/relay/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
|
||||
7
.github/workflows/golangci-lint.yml
vendored
7
.github/workflows/golangci-lint.yml
vendored
@@ -52,10 +52,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
args: --timeout=12m
|
||||
args: --timeout=12m --out-format colored-line-number
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
|
||||
13
.github/workflows/wasm-build-validation.yml
vendored
13
.github/workflows/wasm-build-validation.yml
vendored
@@ -14,9 +14,6 @@ jobs:
|
||||
js_lint:
|
||||
name: "JS / Lint"
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOOS: js
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -27,14 +24,16 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
working-directory: ./client
|
||||
skip-pkg-cache: true
|
||||
skip-build-cache: true
|
||||
- name: Run golangci-lint for WASM
|
||||
run: |
|
||||
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||
continue-on-error: true
|
||||
|
||||
js_build:
|
||||
|
||||
257
.golangci.yaml
257
.golangci.yaml
@@ -1,124 +1,139 @@
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- bodyclose
|
||||
- dupword
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- forbidigo
|
||||
- gocritic
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- mirror
|
||||
- misspell
|
||||
- nilerr
|
||||
- nilnil
|
||||
- predeclared
|
||||
- revive
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- unused
|
||||
- wastedassign
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: false
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
gosec:
|
||||
includes:
|
||||
- G101
|
||||
- G103
|
||||
- G104
|
||||
- G106
|
||||
- G108
|
||||
- G109
|
||||
- G110
|
||||
- G111
|
||||
- G201
|
||||
- G202
|
||||
- G203
|
||||
- G301
|
||||
- G302
|
||||
- G303
|
||||
- G304
|
||||
- G305
|
||||
- G306
|
||||
- G307
|
||||
- G403
|
||||
- G502
|
||||
- G503
|
||||
- G504
|
||||
- G601
|
||||
- G602
|
||||
govet:
|
||||
enable:
|
||||
- nilness
|
||||
enable-all: false
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
arguments:
|
||||
- checkPrivateReceivers
|
||||
- sayRepetitiveInsteadOfStutters
|
||||
severity: warning
|
||||
disabled: false
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 6m
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: false
|
||||
|
||||
gosec:
|
||||
includes:
|
||||
- G101 # Look for hard coded credentials
|
||||
#- G102 # Bind to all interfaces
|
||||
- G103 # Audit the use of unsafe block
|
||||
- G104 # Audit errors not checked
|
||||
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
|
||||
#- G107 # Url provided to HTTP request as taint input
|
||||
- G108 # Profiling endpoint automatically exposed on /debug/pprof
|
||||
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
|
||||
- G110 # Potential DoS vulnerability via decompression bomb
|
||||
- G111 # Potential directory traversal
|
||||
#- G112 # Potential slowloris attack
|
||||
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
|
||||
#- G114 # Use of net/http serve function that has no support for setting timeouts
|
||||
- G201 # SQL query construction using format string
|
||||
- G202 # SQL query construction using string concatenation
|
||||
- G203 # Use of unescaped data in HTML templates
|
||||
#- G204 # Audit use of command execution
|
||||
- G301 # Poor file permissions used when creating a directory
|
||||
- G302 # Poor file permissions used with chmod
|
||||
- G303 # Creating tempfile using a predictable path
|
||||
- G304 # File path provided as taint input
|
||||
- G305 # File traversal when extracting zip/tar archive
|
||||
- G306 # Poor file permissions used when writing to a new file
|
||||
- G307 # Poor file permissions used when creating a file with os.Create
|
||||
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
|
||||
#- G402 # Look for bad TLS connection settings
|
||||
- G403 # Ensure minimum RSA key length of 2048 bits
|
||||
#- G404 # Insecure random number source (rand)
|
||||
#- G501 # Import blocklist: crypto/md5
|
||||
- G502 # Import blocklist: crypto/des
|
||||
- G503 # Import blocklist: crypto/rc4
|
||||
- G504 # Import blocklist: net/http/cgi
|
||||
#- G505 # Import blocklist: crypto/sha1
|
||||
- G601 # Implicit memory aliasing of items from a range statement
|
||||
- G602 # Slice access out of bounds
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: false
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- linters:
|
||||
- forbidigo
|
||||
path: management/cmd/root\.go
|
||||
- linters:
|
||||
- forbidigo
|
||||
path: signal/cmd/root\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: sharedsock/filter\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
path: test\.go
|
||||
- linters:
|
||||
- nilnil
|
||||
path: mock\.go
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: grpc.DialContext is deprecated
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: grpc.WithBlock is deprecated
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1001"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1008"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "QF1012"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
- name: exported
|
||||
severity: warning
|
||||
disabled: false
|
||||
arguments:
|
||||
- "checkPrivateReceivers"
|
||||
- "sayRepetitiveInsteadOfStutters"
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||
- gosimple # specializes in simplifying a code
|
||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # detects when assignments to existing variables are not used
|
||||
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
||||
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # checks for unused constants, variables, functions and types
|
||||
## disable by default but the have interesting results so lets add them
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- dupword # dupword checks for duplicate words in the source code
|
||||
- durationcheck # durationcheck checks for two durations multiplied together
|
||||
- forbidigo # forbidigo forbids identifiers
|
||||
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
||||
- gosec # inspects source code for security problems
|
||||
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
|
||||
- misspell # misspess finds commonly misspelled English words in comments
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 5
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
exclude-rules:
|
||||
# allow fmt
|
||||
- path: management/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: signal/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: sharedsock/filter\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: client/firewall/iptables/rule\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: test\.go
|
||||
linters:
|
||||
- mirror
|
||||
- gosec
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
# Exclude specific deprecation warnings for grpc methods
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.DialContext is deprecated"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.WithBlock is deprecated"
|
||||
|
||||
@@ -38,11 +38,6 @@
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
|
||||
@@ -136,7 +136,6 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
//nolint
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
|
||||
@@ -314,8 +313,9 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
|
||||
statusOutputString = overview.FullDetailSummary()
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
@@ -81,7 +81,6 @@ var loginCmd = &cobra.Command{
|
||||
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
@@ -207,7 +206,6 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
|
||||
@@ -390,7 +390,6 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = outputInformationHolder.FullDetailSummary()
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
case jsonFlag:
|
||||
statusOutputString, err = outputInformationHolder.JSON()
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
case yamlFlag:
|
||||
statusOutputString, err = outputInformationHolder.YAML()
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -124,7 +124,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -89,6 +89,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
@@ -216,7 +216,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -31,11 +29,6 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPeerConnectionTimeout = 60 * time.Second
|
||||
peerConnectionPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -45,7 +38,6 @@ type Client struct {
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
recorder *peer.Status
|
||||
}
|
||||
|
||||
// Options configures a new Client.
|
||||
@@ -169,17 +161,11 @@ func New(opts Options) (*Client, error) {
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.connect != nil {
|
||||
if c.cancel != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
|
||||
defer func() {
|
||||
if c.connect == nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
@@ -187,9 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
c.recorder = recorder
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
@@ -213,7 +197,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
c.cancel = cancel
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -228,23 +211,17 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
c.cancel = nil
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
connect := c.connect
|
||||
go func() {
|
||||
done <- connect.Stop()
|
||||
done <- c.connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.connect = nil
|
||||
c.cancel = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.connect = nil
|
||||
c.cancel = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
@@ -264,40 +241,18 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
// With lazy connections, the connection is established on first traffic.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
|
||||
|
||||
// Check context status upfront
|
||||
if ctx.Err() != nil {
|
||||
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
|
||||
return nil, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when DialContext is called. The netstack
|
||||
// dial triggers WireGuard traffic which activates the lazy connection.
|
||||
|
||||
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
|
||||
conn, err := nsnet.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
logrus.Infof("embed.Dial: successfully connected to %s", address)
|
||||
return conn, nil
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
@@ -360,90 +315,6 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
recorder := c.recorder
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if recorder == nil {
|
||||
return peer.FullStatus{}, errors.New("client not started")
|
||||
}
|
||||
|
||||
if connect != nil {
|
||||
engine := connect.Engine()
|
||||
if engine != nil {
|
||||
_ = engine.RunHealthProbes(false)
|
||||
}
|
||||
}
|
||||
|
||||
return recorder.GetFullStatus(), nil
|
||||
}
|
||||
|
||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncResp, err := engine.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get sync response: %w", err)
|
||||
}
|
||||
|
||||
return syncResp, nil
|
||||
}
|
||||
|
||||
// WaitForPeerConnection waits for a peer with the given IP to be connected.
|
||||
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
|
||||
logrus.Infof("Waiting for peer %s to be connected", peerIP)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
|
||||
case <-ticker.C:
|
||||
status, err := c.Status()
|
||||
if err != nil {
|
||||
logrus.Debugf("Error getting status while waiting for peer: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range status.Peers {
|
||||
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
|
||||
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the client and its components.
|
||||
func (c *Client) SetLogLevel(levelStr string) error {
|
||||
level, err := logrus.ParseLevel(levelStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
|
||||
logrus.SetLevel(level)
|
||||
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
// Note: ConnectClient doesn't have SetLogLevel method
|
||||
_ = connect
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
|
||||
@@ -386,8 +386,11 @@ func (m *aclManager) updateState() {
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
if ip.IsUnspecified() {
|
||||
matchByIP = false
|
||||
}
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
t.Logf(" [%d] %s", i, rule)
|
||||
}
|
||||
|
||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||
var denyRuleIndex, acceptRuleIndex int = -1, -1
|
||||
for i, rule := range rules {
|
||||
if strings.Contains(rule, "DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||
|
||||
// Find the accept and deny rules and verify deny comes before accept
|
||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||
var acceptRuleIndex, denyRuleIndex int = -1, -1
|
||||
for i, rule := range rules {
|
||||
hasAcceptHTTPSet := false
|
||||
hasDenyHTTPSet := false
|
||||
@@ -208,13 +208,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
for _, e := range rule.Exprs {
|
||||
// Check for set lookup
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
switch lookup.SetName {
|
||||
case "accept-http":
|
||||
if lookup.SetName == "accept-http" {
|
||||
hasAcceptHTTPSet = true
|
||||
case "deny-http":
|
||||
} else if lookup.SetName == "deny-http" {
|
||||
hasDenyHTTPSet = true
|
||||
}
|
||||
|
||||
}
|
||||
// Check for port 80
|
||||
if cmp, ok := e.(*expr.Cmp); ok {
|
||||
@@ -224,10 +222,9 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
}
|
||||
// Check for verdict
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch verdict.Kind {
|
||||
case expr.VerdictAccept:
|
||||
if verdict.Kind == expr.VerdictAccept {
|
||||
action = "ACCEPT"
|
||||
case expr.VerdictDrop:
|
||||
} else if verdict.Kind == expr.VerdictDrop {
|
||||
action = "DROP"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
layerTypeAll = 255
|
||||
layerTypeAll = 0
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
@@ -262,7 +262,10 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix := iface.Address().Network
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
rule, err := m.addRouteFiltering(
|
||||
@@ -436,7 +439,19 @@ func (m *Manager) AddPeerFiltering(
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case firewall.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
@@ -481,17 +496,16 @@ func (m *Manager) addRouteFiltering(
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
if destination.IsPrefix() {
|
||||
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||
@@ -781,7 +795,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
var sum = pseudoSum
|
||||
var sum uint32 = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
@@ -931,7 +945,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
|
||||
if blocked {
|
||||
pnum := getProtocolFromPacket(d)
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
@@ -996,22 +1010,20 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return false
|
||||
}
|
||||
|
||||
protoLayer := d.decoded[1]
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeDrop,
|
||||
RuleID: ruleID,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: proto,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
@@ -1040,33 +1052,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
return true
|
||||
}
|
||||
|
||||
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
return layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
return layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
if ipLayer == layers.LayerTypeIPv6 {
|
||||
return layers.LayerTypeICMPv6
|
||||
}
|
||||
return layers.LayerTypeICMPv4
|
||||
case firewall.ProtocolALL:
|
||||
return layerTypeAll
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
|
||||
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return nftypes.TCP
|
||||
return firewall.ProtocolTCP, nftypes.TCP
|
||||
case layers.LayerTypeUDP:
|
||||
return nftypes.UDP
|
||||
return firewall.ProtocolUDP, nftypes.UDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return nftypes.ICMP
|
||||
return firewall.ProtocolICMP, nftypes.ICMP
|
||||
default:
|
||||
return nftypes.ProtocolUnknown
|
||||
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1238,30 +1233,19 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
|
||||
// TODO: handle ipv6 vs ipv4 icmp rules
|
||||
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
|
||||
return false
|
||||
}
|
||||
|
||||
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
destMatched := false
|
||||
for _, dst := range rule.destinations {
|
||||
if dst.Contains(dstAddr) {
|
||||
@@ -1280,8 +1264,21 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
return sourceMatched
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
return false
|
||||
}
|
||||
|
||||
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
|
||||
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
|
||||
@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
for _, tc := range cases {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -17,7 +16,7 @@ type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu atomic.Uint32
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -29,7 +28,7 @@ func (e *endpoint) IsAttached() bool {
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu.Load()
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
@@ -83,22 +82,6 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *endpoint) Close() {
|
||||
// Endpoint cleanup - nothing to do as device is managed externally
|
||||
}
|
||||
|
||||
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
|
||||
// Link address is not used for this endpoint type
|
||||
}
|
||||
|
||||
func (e *endpoint) SetMTU(mtu uint32) {
|
||||
e.mtu.Store(mtu)
|
||||
}
|
||||
|
||||
func (e *endpoint) SetOnCloseAction(func()) {
|
||||
// No action needed on close
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -36,16 +35,14 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -63,8 +60,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
endpoint.mtu.Store(uint32(mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
@@ -106,16 +103,15 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -133,8 +129,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
@@ -204,24 +198,3 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,8 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,95 +14,30 @@ import (
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
// For Echo Requests, send and wait for response
|
||||
if icmpHdr.Type() == header.ICMPv4Echo {
|
||||
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
|
||||
if !f.hasRawICMPAccess {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
|
||||
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
|
||||
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
@@ -113,22 +45,38 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
}
|
||||
}()
|
||||
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu.Load())
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
@@ -137,7 +85,31 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
|
||||
return 0
|
||||
}
|
||||
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
@@ -180,95 +152,3 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
|
||||
// This is used as a fallback when raw socket access is not available.
|
||||
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
|
||||
icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
|
||||
// Returns the size of the injected packet.
|
||||
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyICMPHdr := header.ICMPv4(replyICMP)
|
||||
replyICMPHdr.SetType(header.ICMPv4EchoReply)
|
||||
replyICMPHdr.SetChecksum(0)
|
||||
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
|
||||
|
||||
return f.injectICMPReply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
|
||||
// Returns the total size of the injected packet, or 0 if injection failed.
|
||||
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -132,10 +131,10 @@ func (f *udpForwarder) cleanup() {
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
@@ -145,7 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return true
|
||||
return
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
@@ -163,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
@@ -174,10 +173,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(&wq, ep)
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
@@ -200,7 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return true
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
@@ -209,7 +208,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
@@ -350,7 +348,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
|
||||
@@ -130,7 +130,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
|
||||
@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
|
||||
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// nolint:gosimple
|
||||
_ = mapManager.localIPs[ip.String()]
|
||||
_, _ = mapManager.localIPs[ip.String()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -168,15 +168,6 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
|
||||
@@ -234,10 +234,9 @@ func TestInboundPortDNATNegative(t *testing.T) {
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
d = parsePacket(t, packet)
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
case layers.IPProtocolUDP:
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -34,7 +34,7 @@ type RouteRule struct {
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
protoLayer gopacket.LayerType
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
|
||||
@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
protoLayer := d.decoded[1]
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
|
||||
@@ -27,23 +27,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
||||
}
|
||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
buf := bufs[0]
|
||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = size
|
||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
||||
eps[0] = stdEp
|
||||
return 1, nil
|
||||
}
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
//go:build ios
|
||||
// +build ios
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
|
||||
@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
|
||||
}
|
||||
|
||||
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
|
||||
log.Debugf("dialing %s %s", network, addr)
|
||||
conn, err := d.net.Dial(network, addr)
|
||||
if err != nil {
|
||||
log.Warnf("NSDialer.Dial failed: %s", err)
|
||||
log.Debugf("failed to deal connection: %s", err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
@@ -420,19 +420,6 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
|
||||
return syncResponse, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager if the engine is running.
|
||||
func (c *ConnectClient) SetLogLevel(level log.Level) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager != nil {
|
||||
fwManager.SetLogLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
|
||||
@@ -507,13 +507,15 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
|
||||
if p.Base == expr.PayloadBaseNetworkHeader {
|
||||
switch p.Offset {
|
||||
case 12:
|
||||
switch p.Len {
|
||||
case 4, 2:
|
||||
if p.Len == 4 {
|
||||
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
} else if p.Len == 2 {
|
||||
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
}
|
||||
case 16:
|
||||
switch p.Len {
|
||||
case 4, 2:
|
||||
if p.Len == 4 {
|
||||
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
} else if p.Len == 2 {
|
||||
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.NonAuthoritative {
|
||||
if zone.SkipPTRProcess {
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
|
||||
@@ -3,21 +3,17 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityMgmtCache = 150
|
||||
PriorityDNSRoute = 100
|
||||
PriorityLocal = 75
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityFallback = -100
|
||||
@@ -47,23 +43,7 @@ type HandlerChain struct {
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
origPattern string
|
||||
requestID string
|
||||
shouldContinue bool
|
||||
response *dns.Msg
|
||||
meta map[string]string
|
||||
}
|
||||
|
||||
// RequestID returns the request ID for tracing
|
||||
func (w *ResponseWriterChain) RequestID() string {
|
||||
return w.requestID
|
||||
}
|
||||
|
||||
// SetMeta sets a metadata key-value pair for logging
|
||||
func (w *ResponseWriterChain) SetMeta(key, value string) {
|
||||
if w.meta == nil {
|
||||
w.meta = make(map[string]string)
|
||||
}
|
||||
w.meta[key] = value
|
||||
}
|
||||
|
||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
@@ -72,7 +52,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
w.shouldContinue = true
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -122,8 +101,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
|
||||
c.logHandlers()
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
@@ -163,109 +140,68 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
c.logHandlers()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logHandlers logs the current handler chain state. Caller must hold the lock.
|
||||
func (c *HandlerChain) logHandlers() {
|
||||
if !log.IsLevelEnabled(log.TraceLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
|
||||
for _, h := range c.handlers {
|
||||
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
|
||||
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
|
||||
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
|
||||
" priority=" + strconv.Itoa(h.Priority) + "\n")
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
|
||||
c.mu.RLock()
|
||||
handlers := slices.Clone(c.handlers)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||
for _, h := range handlers {
|
||||
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||
}
|
||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
matched := c.isHandlerMatch(qname, entry)
|
||||
|
||||
handlerName := entry.OrigPattern
|
||||
if s, ok := entry.Handler.(interface{ String() string }); ok {
|
||||
handlerName = s.String()
|
||||
}
|
||||
if matched {
|
||||
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
|
||||
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
requestID: requestID,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
logger.Tracef("handler requested continue for domain=%s", qname)
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
continue
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
c.logResponse(logger, chainWriter, qname, startTime)
|
||||
return
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
// Only log continue for non-management cache handlers to reduce noise
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
logger.Tracef("no handler found for domain=%s type=%s class=%s",
|
||||
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
log.Tracef("no handler found for domain=%s", qname)
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeRefused)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
|
||||
if cw.response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var meta string
|
||||
for k, v := range cw.meta {
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
|
||||
@@ -1,52 +1,30 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const externalResolutionTimeout = 4 * time.Second
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
domains map[domain.Domain]struct{}
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
domains: make(map[domain.Domain]struct{}),
|
||||
zones: make(map[domain.Domain]bool),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,18 +37,7 @@ func (d *Resolver) String() string {
|
||||
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
||||
}
|
||||
|
||||
func (d *Resolver) Stop() {
|
||||
if d.cancel != nil {
|
||||
d.cancel()
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
}
|
||||
func (d *Resolver) Stop() {}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *Resolver) ID() types.HandlerID {
|
||||
@@ -81,85 +48,35 @@ func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
logger.Debug("received local resolver request with no question")
|
||||
log.Debugf("received local resolver request with no question")
|
||||
return
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
result := d.lookupRecords(logger, question)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
|
||||
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
|
||||
d.continueToNext(logger, w, r)
|
||||
return
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(question)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// Check if we have any records for this domain name with different types
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
logger.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// determineRcode returns the appropriate DNS response code.
|
||||
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
|
||||
// even if CNAME records are included in the answer.
|
||||
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
|
||||
// Use the rcode from lookup - this properly handles CNAME chains where
|
||||
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
|
||||
if result.rcode != 0 {
|
||||
return result.rcode
|
||||
}
|
||||
|
||||
// No records found, but domain exists with different record types (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
// findZone finds the matching zone for a query name using reverse suffix lookup.
|
||||
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
|
||||
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
|
||||
qname = strings.ToLower(dns.Fqdn(qname))
|
||||
for {
|
||||
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
|
||||
return nonAuth, true
|
||||
}
|
||||
// Move to parent domain
|
||||
idx := strings.Index(qname, ".")
|
||||
if idx == -1 || idx == len(qname)-1 {
|
||||
return false, false
|
||||
}
|
||||
qname = qname[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// shouldFallthrough checks if the query should fallthrough to the next handler.
|
||||
// Returns true if the queried name belongs to a non-authoritative zone.
|
||||
func (d *Resolver) shouldFallthrough(qname string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
nonAuth, found := d.findZone(qname)
|
||||
return found && nonAuth
|
||||
}
|
||||
|
||||
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Warnf("failed to write continue signal: %v", err)
|
||||
log.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,27 +89,8 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||
return exists
|
||||
}
|
||||
|
||||
// isInManagedZone checks if the given name falls within any of our managed zones.
|
||||
// This is used to avoid unnecessary external resolution for CNAME targets that
|
||||
// are within zones we manage - if we don't have a record for it, it doesn't exist.
|
||||
// Caller must NOT hold the lock.
|
||||
func (d *Resolver) isInManagedZone(name string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, found := d.findZone(name)
|
||||
return found
|
||||
}
|
||||
|
||||
// lookupResult contains the result of a DNS lookup operation.
|
||||
type lookupResult struct {
|
||||
records []dns.RR
|
||||
rcode int
|
||||
hasExternalData bool
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
|
||||
@@ -200,14 +98,10 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
|
||||
d.mu.RUnlock()
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
cnameQuestion := dns.Question{
|
||||
Name: question.Name,
|
||||
Qtype: dns.TypeCNAME,
|
||||
Qclass: question.Qclass,
|
||||
}
|
||||
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
|
||||
question.Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(question)
|
||||
}
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
return nil
|
||||
}
|
||||
|
||||
recordsCopy := slices.Clone(records)
|
||||
@@ -225,178 +119,20 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||
return recordsCopy
|
||||
}
|
||||
|
||||
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||
// the final resolved record of the requested type. This is required for musl libc
|
||||
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
|
||||
const maxDepth = 8
|
||||
var chain []dns.RR
|
||||
|
||||
for range maxDepth {
|
||||
cnameRecords := d.getRecords(cnameQuestion)
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
chain = append(chain, cnameRecords...)
|
||||
|
||||
cname, ok := cnameRecords[0].(*dns.CNAME)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
targetName := strings.ToLower(cname.Target)
|
||||
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
|
||||
|
||||
// keep following chain
|
||||
if targetResult.rcode == -1 {
|
||||
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
|
||||
continue
|
||||
}
|
||||
|
||||
return d.buildChainResult(chain, targetResult)
|
||||
}
|
||||
|
||||
if len(chain) > 0 {
|
||||
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// buildChainResult combines CNAME chain records with the target resolution result.
|
||||
// Per RFC 6604, the final rcode is propagated through the chain.
|
||||
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
|
||||
records := chain
|
||||
if len(target.records) > 0 {
|
||||
records = append(records, target.records...)
|
||||
}
|
||||
|
||||
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
|
||||
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: dns.RcodeServerFailure,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: records,
|
||||
rcode: target.rcode,
|
||||
hasExternalData: target.hasExternalData,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCNAMETarget attempts to resolve a CNAME target name.
|
||||
// Returns rcode=-1 to signal "keep following the chain".
|
||||
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
|
||||
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
|
||||
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// another CNAME, keep following
|
||||
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
|
||||
return lookupResult{rcode: -1}
|
||||
}
|
||||
|
||||
// domain exists locally but not this record type (NODATA)
|
||||
if d.hasRecordsForDomain(domain.Domain(targetName)) {
|
||||
return lookupResult{rcode: dns.RcodeSuccess}
|
||||
}
|
||||
|
||||
// in our zone but doesn't exist (NXDOMAIN)
|
||||
if d.isInManagedZone(targetName) {
|
||||
return lookupResult{rcode: dns.RcodeNameError}
|
||||
}
|
||||
|
||||
return d.resolveExternal(logger, targetName, targetType)
|
||||
}
|
||||
|
||||
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.records[q]
|
||||
}
|
||||
|
||||
func (d *Resolver) hasRecord(q dns.Question) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
_, ok := d.records[q]
|
||||
return ok
|
||||
}
|
||||
|
||||
// resolveExternal resolves a domain name using the system resolver.
|
||||
// This is used to resolve CNAME targets that point outside our local zone,
|
||||
// which is required for musl libc compatibility (musl expects complete answers).
|
||||
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return lookupResult{rcode: dns.RcodeNotImplemented}
|
||||
}
|
||||
|
||||
resolver := d.resolver
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
|
||||
if result.Err != nil {
|
||||
d.logDNSError(logger, name, qtype, result.Err)
|
||||
return lookupResult{rcode: result.Rcode, hasExternalData: true}
|
||||
}
|
||||
|
||||
return lookupResult{
|
||||
records: resutil.IPsToRRs(name, result.IPs, 60),
|
||||
rcode: dns.RcodeSuccess,
|
||||
hasExternalData: true,
|
||||
}
|
||||
}
|
||||
|
||||
// logDNSError logs DNS resolution errors for debugging.
|
||||
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
|
||||
qtypeName := dns.TypeToString[qtype]
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.IsNotFound {
|
||||
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
|
||||
} else {
|
||||
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
maps.Clear(d.zones)
|
||||
|
||||
for _, zone := range customZones {
|
||||
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
||||
d.zones[zoneDomain] = zone.NonAuthoritative
|
||||
|
||||
for _, rec := range zone.Records {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
}
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,18 +12,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
// mockResolver implements resolver for testing
|
||||
type mockResolver struct {
|
||||
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if m.lookupFunc != nil {
|
||||
return m.lookupFunc(ctx, network, host)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -124,11 +106,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
|
||||
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
|
||||
update1 := []nbdns.SimpleRecord{record1}
|
||||
update2 := []nbdns.SimpleRecord{record2}
|
||||
|
||||
// Apply first update
|
||||
resolver.Update(zone1)
|
||||
resolver.Update(update1)
|
||||
|
||||
// Verify first update
|
||||
resolver.mu.RLock()
|
||||
@@ -140,7 +122,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||
|
||||
// Apply second update
|
||||
resolver.Update(zone2)
|
||||
resolver.Update(update2)
|
||||
|
||||
// Verify second update
|
||||
resolver.mu.RLock()
|
||||
@@ -169,10 +151,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||
}
|
||||
|
||||
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
|
||||
update := []nbdns.SimpleRecord{record1, record2}
|
||||
|
||||
// Apply update with both records
|
||||
resolver.Update(zones)
|
||||
resolver.Update(update)
|
||||
|
||||
// Create question that matches both records
|
||||
question := dns.Question{
|
||||
@@ -213,10 +195,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||
}
|
||||
|
||||
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
|
||||
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||
|
||||
// Apply update with all three records
|
||||
resolver.Update(zones)
|
||||
resolver.Update(update)
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||
|
||||
@@ -282,7 +264,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update resolver with the records
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
|
||||
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -397,7 +379,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update resolver with both records
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
|
||||
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -494,20 +476,6 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
// with 0 records instead of NXDOMAIN
|
||||
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
// Mock external resolver for CNAME target resolution
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "target.example.com." {
|
||||
if network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
if network == "ip6" {
|
||||
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||
}
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
}
|
||||
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "example.netbird.cloud.",
|
||||
@@ -525,7 +493,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
RData: "target.example.com.",
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
|
||||
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -614,808 +582,3 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
|
||||
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
|
||||
t.Run("simple internal CNAME chain", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
|
||||
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 2)
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "target.example.com.", cname.Target)
|
||||
|
||||
a, ok := resp.Answer[1].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "192.168.1.1", a.A.String())
|
||||
})
|
||||
|
||||
t.Run("multi-hop CNAME chain", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
|
||||
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
|
||||
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 3)
|
||||
})
|
||||
|
||||
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
|
||||
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
|
||||
t.Run("chain at max depth resolves", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
var records []nbdns.SimpleRecord
|
||||
// Create chain of 7 CNAMEs (under max of 8)
|
||||
for i := 1; i <= 7; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("hop%d.test.", i),
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("hop%d.test.", i+1),
|
||||
})
|
||||
}
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||
})
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 8)
|
||||
})
|
||||
|
||||
t.Run("chain exceeding max depth stops", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
var records []nbdns.SimpleRecord
|
||||
// Create chain of 10 CNAMEs (exceeds max of 8)
|
||||
for i := 1; i <= 10; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("deep%d.test.", i),
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("deep%d.test.", i+1),
|
||||
})
|
||||
}
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||
})
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
// Should NOT have the final A record (chain too deep)
|
||||
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||
})
|
||||
|
||||
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
|
||||
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
|
||||
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
|
||||
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "external.example.com.", cname.Target)
|
||||
|
||||
a, ok := resp.Answer[1].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "93.184.216.34", a.A.String())
|
||||
})
|
||||
|
||||
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip6" {
|
||||
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
|
||||
|
||||
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "external.example.com.", cname.Target)
|
||||
|
||||
aaaa, ok := resp.Answer[1].(*dns.AAAA)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
|
||||
})
|
||||
|
||||
t.Run("concurrent external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]*dns.Msg, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
results[idx] = resp
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, resp := range results {
|
||||
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
|
||||
func TestLocalResolver_ZoneManagement(t *testing.T) {
|
||||
t.Run("Update sets zones correctly", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{
|
||||
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
}},
|
||||
{Domain: "test.local."},
|
||||
})
|
||||
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("other.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("sub.test.local."))
|
||||
assert.False(t, resolver.isInManagedZone("external.com."))
|
||||
})
|
||||
|
||||
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
|
||||
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
|
||||
})
|
||||
|
||||
t.Run("Update clears zones", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
|
||||
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||
|
||||
resolver.Update(nil)
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
|
||||
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
|
||||
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
|
||||
})
|
||||
|
||||
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.other.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
|
||||
})
|
||||
|
||||
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "myzone.test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
|
||||
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "Answer should be CNAME record")
|
||||
})
|
||||
|
||||
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." {
|
||||
if network == "ip6" {
|
||||
// No AAAA records
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
}
|
||||
if network == "ip4" {
|
||||
// But A records exist - domain exists
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||
require.Len(t, resp.Answer, 1, "Should have only CNAME")
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "Answer should be CNAME record")
|
||||
})
|
||||
|
||||
// Table-driven test for all external resolution outcomes
|
||||
externalCases := []struct {
|
||||
name string
|
||||
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
|
||||
expectedRcode int
|
||||
expectedAnswer int
|
||||
}{
|
||||
{
|
||||
name: "external NXDOMAIN (both A and AAAA not found)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (temporary error)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsTemporary: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (timeout)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, &net.DNSError{IsTimeout: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external SERVFAIL (generic error)",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
},
|
||||
expectedRcode: dns.RcodeServerFailure,
|
||||
expectedAnswer: 1, // CNAME only
|
||||
},
|
||||
{
|
||||
name: "external success with IPs",
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||
},
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
expectedAnswer: 2, // CNAME + A
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range externalCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
|
||||
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
|
||||
if tc.expectedAnswer > 0 {
|
||||
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||
assert.True(t, ok, "first answer should be CNAME")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
|
||||
// trigger fallthrough (Zero bit set) when no records match
|
||||
func TestLocalResolver_Fallthrough(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
record := nbdns.SimpleRecord{
|
||||
Name: "existing.custom.zone.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.0.0.1",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
zones []nbdns.CustomZone
|
||||
queryName string
|
||||
expectFallthrough bool
|
||||
expectRecord bool
|
||||
}{
|
||||
{
|
||||
name: "Authoritative zone returns NXDOMAIN without fallthrough",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
}},
|
||||
queryName: "nonexistent.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Non-authoritative zone triggers fallthrough",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
NonAuthoritative: true,
|
||||
}},
|
||||
queryName: "nonexistent.custom.zone.",
|
||||
expectFallthrough: true,
|
||||
expectRecord: false,
|
||||
},
|
||||
{
|
||||
name: "Record found in non-authoritative zone returns normally",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
NonAuthoritative: true,
|
||||
}},
|
||||
queryName: "existing.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: true,
|
||||
},
|
||||
{
|
||||
name: "Record found in authoritative zone returns normally",
|
||||
zones: []nbdns.CustomZone{{
|
||||
Domain: "custom.zone.",
|
||||
Records: []nbdns.SimpleRecord{record},
|
||||
}},
|
||||
queryName: "existing.custom.zone.",
|
||||
expectFallthrough: false,
|
||||
expectRecord: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver.Update(tc.zones)
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG, "Should have received a response")
|
||||
|
||||
if tc.expectFallthrough {
|
||||
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
|
||||
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
|
||||
} else {
|
||||
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
|
||||
}
|
||||
|
||||
if tc.expectRecord {
|
||||
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
|
||||
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
t.Run("direct record lookup is authoritative", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.Authoritative)
|
||||
})
|
||||
|
||||
t.Run("external resolution is not authoritative", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||
if host == "external.example.com." && network == "ip4" {
|
||||
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 2)
|
||||
assert.False(t, resp.Authoritative)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
resolver.Stop()
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||
var resp *dns.Msg
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Len(t, resp.Answer, 0)
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
},
|
||||
}})
|
||||
|
||||
resolver.Stop()
|
||||
resolver.Stop()
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
lookupCtxCanceled := make(chan struct{})
|
||||
|
||||
resolver.resolver = &mockResolver{
|
||||
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
close(lookupStarted)
|
||||
<-ctx.Done()
|
||||
close(lookupCtxCanceled)
|
||||
return nil, ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "test.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||
},
|
||||
}})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-lookupStarted
|
||||
resolver.Stop()
|
||||
|
||||
select {
|
||||
case <-lookupCtxCanceled:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("external lookup context was not canceled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ServeDNS did not return after Stop")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
|
||||
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "EXAMPLE.COM.",
|
||||
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG)
|
||||
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
|
||||
func BenchmarkFindZone_BestCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Single zone that matches immediately
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough("example.com.")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
|
||||
func BenchmarkFindZone_WorstCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// 100 zones that won't match
|
||||
var zones []nbdns.CustomZone
|
||||
for i := 0; i < 100; i++ {
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("zone%d.internal.", i),
|
||||
NonAuthoritative: true,
|
||||
})
|
||||
}
|
||||
resolver.Update(zones)
|
||||
|
||||
// Query with many labels that won't match any zone
|
||||
qname := "a.b.c.d.e.f.g.h.external.com."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
|
||||
func BenchmarkFindZone_TypicalCase(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
|
||||
resolver.Update([]nbdns.CustomZone{
|
||||
{Domain: "netbird.cloud.", NonAuthoritative: false},
|
||||
{Domain: "custom.local.", NonAuthoritative: true},
|
||||
})
|
||||
|
||||
// Query for subdomain of user zone
|
||||
qname := "myhost.custom.local."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.shouldFallthrough(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
|
||||
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver := NewResolver()
|
||||
|
||||
var zones []nbdns.CustomZone
|
||||
for i := 0; i < 100; i++ {
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("zone%d.internal.", i),
|
||||
})
|
||||
}
|
||||
resolver.Update(zones)
|
||||
|
||||
// Query that matches zone50
|
||||
qname := "host.zone50.internal."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
// Package resutil provides shared DNS resolution utilities
|
||||
package resutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
log.Errorf("generate request ID: %v", err)
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// IPsToRRs converts a slice of IP addresses to DNS resource records.
|
||||
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
|
||||
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
|
||||
var result []dns.RR
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.Is6() {
|
||||
result = append(result, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
})
|
||||
} else {
|
||||
result = append(result, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: ttl,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
|
||||
// Returns empty string for unsupported types.
|
||||
func NetworkForQtype(qtype uint16) string {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
return "ip4"
|
||||
case dns.TypeAAAA:
|
||||
return "ip6"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
// chainedWriter is implemented by ResponseWriters that carry request metadata
|
||||
type chainedWriter interface {
|
||||
RequestID() string
|
||||
SetMeta(key, value string)
|
||||
}
|
||||
|
||||
// GetRequestID extracts a request ID from the ResponseWriter if available,
|
||||
// otherwise generates a new one.
|
||||
func GetRequestID(w dns.ResponseWriter) string {
|
||||
if cw, ok := w.(chainedWriter); ok {
|
||||
if id := cw.RequestID(); id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return GenerateRequestID()
|
||||
}
|
||||
|
||||
// SetMeta sets metadata on the ResponseWriter if it supports it.
|
||||
func SetMeta(w dns.ResponseWriter, key, value string) {
|
||||
if cw, ok := w.(chainedWriter); ok {
|
||||
cw.SetMeta(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// LookupResult contains the result of an external DNS lookup
|
||||
type LookupResult struct {
|
||||
IPs []netip.Addr
|
||||
Rcode int
|
||||
Err error // Original error for caller's logging needs
|
||||
}
|
||||
|
||||
// LookupIP performs a DNS lookup and determines the appropriate rcode.
|
||||
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
|
||||
ips, err := r.LookupNetIP(ctx, network, host)
|
||||
if err != nil {
|
||||
return LookupResult{
|
||||
Rcode: getRcodeForError(ctx, r, host, qtype, err),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||
for i, ip := range ips {
|
||||
ips[i] = ip.Unmap()
|
||||
}
|
||||
|
||||
return LookupResult{
|
||||
IPs: ips,
|
||||
Rcode: dns.RcodeSuccess,
|
||||
}
|
||||
}
|
||||
|
||||
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
return dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
if dnsErr.IsNotFound {
|
||||
return getRcodeForNotFound(ctx, r, host, qtype)
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
|
||||
// (domain exists but no records of requested type) by checking the opposite record type.
|
||||
//
|
||||
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
|
||||
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
|
||||
// are not queried by musl and don't need this handling.
|
||||
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
|
||||
// Try querying for a different record type to see if the domain exists
|
||||
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||
// This helps distinguish between NXDOMAIN and NODATA.
|
||||
var alternativeNetwork string
|
||||
switch originalQtype {
|
||||
case dns.TypeAAAA:
|
||||
alternativeNetwork = "ip4"
|
||||
case dns.TypeA:
|
||||
alternativeNetwork = "ip6"
|
||||
default:
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
|
||||
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||
// Alternative query also returned not found - domain truly doesn't exist
|
||||
return dns.RcodeNameError
|
||||
}
|
||||
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// Alternative query succeeded - domain exists but has no records of this type
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// FormatAnswers formats DNS resource records for logging.
|
||||
func FormatAnswers(answers []dns.RR) string {
|
||||
if len(answers) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(answers))
|
||||
for _, rr := range answers {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
parts = append(parts, r.A.String())
|
||||
case *dns.AAAA:
|
||||
parts = append(parts, r.AAAA.String())
|
||||
case *dns.CNAME:
|
||||
parts = append(parts, "CNAME:"+r.Target)
|
||||
case *dns.PTR:
|
||||
parts = append(parts, "PTR:"+r.Ptr)
|
||||
default:
|
||||
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
|
||||
}
|
||||
}
|
||||
return "[" + strings.Join(parts, ", ") + "]"
|
||||
}
|
||||
@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
@@ -498,7 +498,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
|
||||
s.localResolver.Update(localZones)
|
||||
// register local records
|
||||
s.localResolver.Update(localRecords)
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
@@ -631,7 +632,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
@@ -656,9 +659,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
var zones []nbdns.CustomZone
|
||||
var localRecords []nbdns.SimpleRecord
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
@@ -672,20 +675,17 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
priority: PriorityLocal,
|
||||
})
|
||||
|
||||
// zone records contain the fqdn, so we can just flatten them
|
||||
var localRecords []nbdns.SimpleRecord
|
||||
for _, record := range customZone.Records {
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
// zone records contain the fqdn, so we can just flatten them
|
||||
localRecords = append(localRecords, record)
|
||||
}
|
||||
customZone.Records = localRecords
|
||||
zones = append(zones, customZone)
|
||||
}
|
||||
|
||||
return muxUpdates, zones, nil
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||
@@ -741,7 +741,9 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
domainGroup.domain,
|
||||
@@ -922,7 +924,9 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
@@ -82,10 +81,6 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
||||
return configurer.WGStats{}, nil
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -133,7 +128,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalZones []nbdns.CustomZone
|
||||
initLocalRecords []nbdns.SimpleRecord
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
@@ -185,8 +180,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
name: "New Config Should Succeed",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
@@ -226,19 +221,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -256,11 +251,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -278,11 +273,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -304,8 +299,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
@@ -320,8 +315,8 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
@@ -390,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalZones)
|
||||
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
@@ -515,7 +510,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
priority: PriorityUpstream,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
|
||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -2052,7 +2048,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||
|
||||
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
// Test that priority constants are ordered correctly
|
||||
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
|
||||
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
@@ -18,10 +19,8 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -114,7 +113,10 @@ func (u *upstreamResolverBase) Stop() {
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||
requestID := GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
|
||||
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
u.prepareRequest(r)
|
||||
|
||||
@@ -200,18 +202,11 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
rm.MsgHdr.Zero = false
|
||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
|
||||
|
||||
if err := w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -419,56 +414,16 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
log.Errorf("failed to generate request ID: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
|
||||
conn, err := nsNet.DialContext(ctx, network, upstream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("with %s: %w", network, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close DNS connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
}
|
||||
|
||||
reply, err := dnsConn.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read %s message: %w", network, err)
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
|
||||
@@ -23,7 +23,9 @@ type upstreamResolver struct {
|
||||
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ WGIface,
|
||||
_ string,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
|
||||
@@ -5,23 +5,22 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
*upstreamResolverBase
|
||||
nsNet *netstack.Net
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
wgIface WGIface,
|
||||
_ string,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -29,23 +28,12 @@ func newUpstreamResolver(
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
nsNet: wgIface.GetNet(),
|
||||
}
|
||||
upstreamResolverBase.upstreamClient = nonIOS
|
||||
return nonIOS, nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
|
||||
// Similar to iOS logic, we should determine if the DNS server is reachable directly
|
||||
// or needs to go through the tunnel, and only use netstack when necessary.
|
||||
// For now, only use netstack on JS platform where direct access is not possible.
|
||||
if u.nsNet != nil && runtime.GOOS == "js" {
|
||||
start := time.Now()
|
||||
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
|
||||
return reply, time.Since(start), err
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
|
||||
@@ -26,7 +26,9 @@ type upstreamResolverIOS struct {
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
wgIface WGIface,
|
||||
interfaceName string,
|
||||
ip netip.Addr,
|
||||
net netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@@ -35,9 +37,9 @@ func newUpstreamResolver(
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: wgIface.Address().IP,
|
||||
lNet: wgIface.Address().Network,
|
||||
interfaceName: wgIface.Name(),
|
||||
lIP: ip,
|
||||
lNet: net,
|
||||
interfaceName: interfaceName,
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
|
||||
@@ -2,17 +2,13 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
@@ -62,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
// Convert test servers to netip.AddrPort
|
||||
var servers []netip.AddrPort
|
||||
for _, server := range testCase.InputServers {
|
||||
@@ -116,19 +112,6 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockNetstackProvider struct{}
|
||||
|
||||
func (m *mockNetstackProvider) Name() string { return "mock" }
|
||||
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
|
||||
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
|
||||
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
|
||||
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
|
||||
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
|
||||
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
|
||||
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUpstreamResolver struct {
|
||||
r *dns.Msg
|
||||
rtt time.Duration
|
||||
|
||||
@@ -5,8 +5,6 @@ package dns
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -19,5 +17,4 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -14,6 +12,5 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -190,22 +189,29 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||
if len(query.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := query.Question[0]
|
||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||
question.Name, question.Qtype, question.Qclass)
|
||||
|
||||
domain := strings.ToLower(question.Name)
|
||||
|
||||
resp := query.SetReply(query)
|
||||
network := resutil.NetworkForQtype(question.Qtype)
|
||||
if network == "" {
|
||||
var network string
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
network = "ip4"
|
||||
case dns.TypeAAAA:
|
||||
network = "ip6"
|
||||
default:
|
||||
// TODO: Handle other types
|
||||
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -215,35 +221,33 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
|
||||
if mostSpecificResId == "" {
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||
if result.Err != nil {
|
||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||
if err != nil {
|
||||
f.handleDNSError(ctx, w, question, resp, domain, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||
f.cache.set(domain, question.Qtype, result.IPs)
|
||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||
for i, ip := range ips {
|
||||
ips[i] = ip.Unmap()
|
||||
}
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
resp := f.handleDNSQuery(w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
@@ -261,33 +265,19 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
|
||||
resp := f.handleDNSQuery(logger, w, query)
|
||||
resp := f.handleDNSQuery(w, query)
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
logger.Errorf("failed to write DNS response: %v", err)
|
||||
return
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||
@@ -325,64 +315,140 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
||||
}
|
||||
}
|
||||
|
||||
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
||||
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
||||
//
|
||||
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
||||
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
||||
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
||||
// only handles A/AAAA queries and returns NOTIMP for other types.
|
||||
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
||||
// Try querying for a different record type to see if the domain exists
|
||||
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||
// This helps distinguish between NXDOMAIN and NODATA.
|
||||
var alternativeNetwork string
|
||||
switch originalQtype {
|
||||
case dns.TypeAAAA:
|
||||
alternativeNetwork = "ip4"
|
||||
case dns.TypeA:
|
||||
alternativeNetwork = "ip6"
|
||||
default:
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||
// Alternative query also returned not found - domain truly doesn't exist
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
return
|
||||
}
|
||||
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
return
|
||||
}
|
||||
|
||||
// Alternative query succeeded - domain exists but has no records of this type
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
logger *log.Entry,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
result resutil.LookupResult,
|
||||
err error,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
resp.Rcode = result.Rcode
|
||||
|
||||
// NotFound: cache negative result and respond
|
||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cached negative result - re-verify NXDOMAIN vs NODATA
|
||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||
resp.Rcode = verifyResult.Rcode
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache or verification failed. Log with or without the server field for more context.
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
} else {
|
||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
||||
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
||||
for _, ip := range ips {
|
||||
var respRecord dns.RR
|
||||
if ip.Is6() {
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
} else {
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
}
|
||||
resp.Answer = append(resp.Answer, respRecord)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -318,7 +317,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
if tt.shouldResolve {
|
||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||
@@ -466,7 +465,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
||||
|
||||
// Verify response
|
||||
if tt.shouldResolve {
|
||||
@@ -528,7 +527,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
query.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
// Verify response contains all IPs
|
||||
require.NotNil(t, resp)
|
||||
@@ -605,7 +604,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
_ = forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
// Check the response written to the writer
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
@@ -675,7 +674,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -685,7 +684,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
@@ -715,7 +714,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
@@ -729,7 +728,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
@@ -784,7 +783,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||
|
||||
mockWriter := &test.MockResponseWriter{}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||
@@ -905,7 +904,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||
if resp != nil && writtenResp == nil {
|
||||
@@ -938,7 +937,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
assert.Nil(t, resp, "Should return nil for empty query")
|
||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||
|
||||
@@ -1251,16 +1251,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
||||
ForwarderPort: forwarderPort,
|
||||
}
|
||||
|
||||
protoZones := protoDNSConfig.GetCustomZones()
|
||||
// Treat single zone as authoritative for backward compatibility with old servers
|
||||
// that only send the peer FQDN zone without setting field 4.
|
||||
singleZoneCompat := len(protoZones) == 1
|
||||
|
||||
for _, zone := range protoZones {
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
dnsZone := nbdns.CustomZone{
|
||||
Domain: zone.GetDomain(),
|
||||
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
|
||||
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
dnsRecord := nbdns.SimpleRecord{
|
||||
@@ -1748,26 +1743,22 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
// Skip STUN/TURN probing for JS/WASM as it's not available
|
||||
relayHealthy := true
|
||||
if runtime.GOOS != "js" {
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
for _, res := range results {
|
||||
if res.Err != nil {
|
||||
relayHealthy = false
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
for _, res := range results {
|
||||
if res.Err != nil {
|
||||
relayHealthy = false
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||
|
||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package internal
|
||||
|
||||
|
||||
@@ -669,17 +669,10 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -159,7 +158,6 @@ type FullStatus struct {
|
||||
NSGroupStates []NSGroupState
|
||||
NumOfForwardingRules int
|
||||
LazyConnectionEnabled bool
|
||||
Events []*proto.SystemEvent
|
||||
}
|
||||
|
||||
type StatusChangeSubscription struct {
|
||||
@@ -983,7 +981,6 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
}
|
||||
|
||||
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
|
||||
fullStatus.Events = d.GetEventHistory()
|
||||
return fullStatus
|
||||
}
|
||||
|
||||
@@ -1184,97 +1181,3 @@ type EventSubscription struct {
|
||||
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
|
||||
return s.events
|
||||
}
|
||||
|
||||
// ToProto converts FullStatus to proto.FullStatus.
|
||||
func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
|
||||
if err := fs.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fs.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
|
||||
if err := fs.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
|
||||
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
|
||||
|
||||
for _, peerState := range fs.Peers {
|
||||
networks := maps.Keys(peerState.GetRoutes())
|
||||
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: networks,
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fs.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fs.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
pbFullStatus.Events = fs.Events
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
@@ -17,13 +17,12 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -38,6 +37,11 @@ type internalDNATer interface {
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
type wgInterface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
@@ -47,7 +51,7 @@ type DnsInterceptor struct {
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedDomains domainMap
|
||||
wgInterface iface.WGIface
|
||||
wgInterface wgInterface
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
@@ -215,14 +219,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
|
||||
// ServeDNS implements the dns.Handler interface
|
||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"request_id": resutil.GetRequestID(w),
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
requestID := nbdns.GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
// pass if non A/AAAA query
|
||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||
@@ -245,6 +249,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if err != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if r.Extra == nil {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
@@ -253,15 +263,32 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
|
||||
if reply == nil {
|
||||
startTime := time.Now()
|
||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resutil.SetMeta(w, "peer", peerKey)
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
|
||||
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply, logger); err != nil {
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -297,15 +324,11 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
||||
return peerAllowedIP, nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("received nil DNS message")
|
||||
}
|
||||
|
||||
// Clear Zero bit from peer responses to prevent external sources from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
r.MsgHdr.Zero = false
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
origPattern := ""
|
||||
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||
@@ -327,14 +350,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
@@ -347,11 +370,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
|
||||
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,22 +386,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
||||
}
|
||||
|
||||
// logPrefixChanges handles the logging for prefix changes
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||
if len(toAdd) > 0 {
|
||||
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
@@ -395,9 +418,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||
dnatMappings[fakeIP] = realIP
|
||||
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
} else {
|
||||
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
|
||||
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -409,7 +432,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
}
|
||||
}
|
||||
|
||||
d.addDNATMappings(dnatMappings, logger)
|
||||
d.addDNATMappings(dnatMappings)
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
@@ -425,7 +448,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
}
|
||||
}
|
||||
|
||||
d.removeDNATMappings(toRemove, logger)
|
||||
d.removeDNATMappings(toRemove)
|
||||
}
|
||||
|
||||
// Update domain prefixes using resolved domain as key - store real IPs
|
||||
@@ -440,14 +463,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
// Store real IPs for status (user-facing), not fake IPs
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
|
||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||
if len(realPrefixes) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -461,9 +484,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
} else {
|
||||
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -479,7 +502,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||
}
|
||||
|
||||
// addDNATMappings adds DNAT mappings to the firewall
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -491,9 +514,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, log
|
||||
|
||||
for fakeIP, realIP := range mappings {
|
||||
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
} else {
|
||||
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -505,12 +528,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||
}
|
||||
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
|
||||
d.removeDNATMappings(prefixes)
|
||||
}
|
||||
}
|
||||
|
||||
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
@@ -526,7 +549,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.A = fakeIP.AsSlice()
|
||||
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
|
||||
case *dns.AAAA:
|
||||
@@ -537,7 +560,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.AAAA = fakeIP.AsSlice()
|
||||
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -563,44 +586,6 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
|
||||
return
|
||||
}
|
||||
|
||||
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
|
||||
// Returns the DNS reply on success, or nil on error (error responses are written internally).
|
||||
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
|
||||
startTime := time.Now()
|
||||
|
||||
nsNet := d.wgInterface.GetNet()
|
||||
var reply *dns.Msg
|
||||
var err error
|
||||
|
||||
if nsNet != nil {
|
||||
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
|
||||
} else {
|
||||
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if clientErr != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
|
||||
return nil
|
||||
}
|
||||
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return reply
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
||||
if d.statusRecorder == nil {
|
||||
return ""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package iface
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -20,5 +18,4 @@ type wgIfaceBase interface {
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetNet() *netstack.Net
|
||||
}
|
||||
|
||||
@@ -210,8 +210,7 @@ func (r *SysOps) refreshLocalSubnetsCache() {
|
||||
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
switch prefix {
|
||||
case vars.Defaultv4:
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -234,7 +233,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
||||
}
|
||||
|
||||
return nil
|
||||
case vars.Defaultv6:
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
@@ -256,8 +255,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
||||
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
switch prefix {
|
||||
case vars.Defaultv4:
|
||||
if prefix == vars.Defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
@@ -275,7 +273,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
case vars.Defaultv6:
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
@@ -285,9 +283,9 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
default:
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
|
||||
@@ -76,7 +76,7 @@ type Client struct {
|
||||
loginComplete bool
|
||||
connectClient *internal.ConnectClient
|
||||
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
|
||||
preloadedConfig *profilemanager.Config
|
||||
preloadedConfig *profilemanager.Config
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
|
||||
@@ -173,9 +173,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
log.SetLevel(level)
|
||||
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetLogLevel(level)
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, fmt.Errorf("firewall manager not initialized")
|
||||
}
|
||||
|
||||
fwManager.SetLogLevel(level)
|
||||
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -27,3 +29,8 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
|
||||
events := s.statusRecorder.GetEventHistory()
|
||||
return &proto.GetEventsResponse{Events: events}, nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package server
|
||||
|
||||
|
||||
@@ -13,12 +13,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -1064,9 +1067,11 @@ func (s *Server) Status(
|
||||
if msg.GetFullPeerStatus {
|
||||
s.runProbes(msg.ShouldRunProbes)
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1595,6 +1600,94 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
SignalState: &proto.SignalState{},
|
||||
LocalPeerState: &proto.LocalPeerState{},
|
||||
Peers: []*proto.PeerState{},
|
||||
}
|
||||
|
||||
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
|
||||
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
|
||||
if err := fullStatus.ManagementState.Error; err != nil {
|
||||
pbFullStatus.ManagementState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
|
||||
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
|
||||
if err := fullStatus.SignalState.Error; err != nil {
|
||||
pbFullStatus.SignalState.Error = err.Error()
|
||||
}
|
||||
|
||||
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
|
||||
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
|
||||
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
|
||||
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
|
||||
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
|
||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
pbPeerState := &proto.PeerState{
|
||||
IP: peerState.IP,
|
||||
PubKey: peerState.PubKey,
|
||||
ConnStatus: peerState.ConnStatus.String(),
|
||||
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
|
||||
Relayed: peerState.Relayed,
|
||||
LocalIceCandidateType: peerState.LocalIceCandidateType,
|
||||
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
|
||||
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
|
||||
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
|
||||
RelayAddress: peerState.RelayServerAddress,
|
||||
Fqdn: peerState.FQDN,
|
||||
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
for _, relayState := range fullStatus.Relays {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
}
|
||||
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
|
||||
}
|
||||
|
||||
for _, dnsState := range fullStatus.NSGroupStates {
|
||||
var err string
|
||||
if dnsState.Error != nil {
|
||||
err = dnsState.Error.Error()
|
||||
}
|
||||
|
||||
var servers []string
|
||||
for _, server := range dnsState.Servers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
|
||||
pbDnsState := &proto.NSGroupState{
|
||||
Servers: servers,
|
||||
Domains: dnsState.Domains,
|
||||
Enabled: dnsState.Enabled,
|
||||
Error: err,
|
||||
}
|
||||
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
|
||||
}
|
||||
|
||||
return &pbFullStatus
|
||||
}
|
||||
|
||||
// sendTerminalNotification sends a terminal notification message
|
||||
// to inform the user that the NetBird connection session has expired.
|
||||
func sendTerminalNotification() error {
|
||||
|
||||
@@ -602,13 +602,12 @@ func TestJWTAuthentication(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var authMethods []cryptossh.AuthMethod
|
||||
switch tc.token {
|
||||
case "valid":
|
||||
if tc.token == "valid" {
|
||||
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
}
|
||||
case "invalid":
|
||||
} else if tc.token == "invalid" {
|
||||
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(invalidToken),
|
||||
|
||||
@@ -325,64 +325,61 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
// JSON returns the status overview as a JSON string.
|
||||
func (o *OutputOverview) JSON() (string, error) {
|
||||
jsonBytes, err := json.Marshal(o)
|
||||
func ParseToJSON(overview OutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
// YAML returns the status overview as a YAML string.
|
||||
func (o *OutputOverview) YAML() (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(o)
|
||||
func ParseToYAML(overview OutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
// GeneralSummary returns a general summary of the status overview.
|
||||
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
var managementConnString string
|
||||
if o.ManagementState.Connected {
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if o.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var signalConnString string
|
||||
if o.SignalState.Connected {
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if o.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := o.IP
|
||||
if o.KernelInterface {
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if o.IP == "" {
|
||||
} else if overview.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range o.Relays.Details {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
|
||||
@@ -398,18 +395,18 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(o.Networks) > 0 {
|
||||
sort.Strings(o.Networks)
|
||||
networks = strings.Join(o.Networks, ", ")
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range o.NSServerGroups {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
@@ -433,25 +430,25 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if o.RosenpassEnabled {
|
||||
if overview.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if o.RosenpassPermissive {
|
||||
if overview.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
lazyConnectionEnabledStatus := "false"
|
||||
if o.LazyConnectionEnabled {
|
||||
if overview.LazyConnectionEnabled {
|
||||
lazyConnectionEnabledStatus = "true"
|
||||
}
|
||||
|
||||
sshServerStatus := "Disabled"
|
||||
if o.SSHServerState.Enabled {
|
||||
sessionCount := len(o.SSHServerState.Sessions)
|
||||
if overview.SSHServerState.Enabled {
|
||||
sessionCount := len(overview.SSHServerState.Sessions)
|
||||
if sessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if sessionCount > 1 {
|
||||
@@ -463,7 +460,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
|
||||
if showSSHSessions && sessionCount > 0 {
|
||||
for _, session := range o.SSHServerState.Sessions {
|
||||
for _, session := range overview.SSHServerState.Sessions {
|
||||
var sessionDisplay string
|
||||
if session.JWTUsername != "" {
|
||||
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
||||
@@ -487,7 +484,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
@@ -515,31 +512,30 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Forwarding rules: %d\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
o.DaemonVersion,
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
o.ProfileName,
|
||||
overview.ProfileName,
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
domain.Domain(o.FQDN).SafeString(),
|
||||
domain.Domain(overview.FQDN).SafeString(),
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
networks,
|
||||
o.NumberOfForwardingRules,
|
||||
overview.NumberOfForwardingRules,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
// FullDetailSummary returns a full detailed summary with peer details and events.
|
||||
func (o *OutputOverview) FullDetailSummary() string {
|
||||
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(o.Events)
|
||||
summary := o.GeneralSummary(true, true, true, true)
|
||||
func ParseToFullDetailSummary(overview OutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(overview.Events)
|
||||
summary := ParseGeneralSummary(overview, true, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
jsonString, _ := overview.JSON()
|
||||
jsonString, _ := ParseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSONString := `
|
||||
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := overview.YAML()
|
||||
yaml, _ := ParseToYAML(overview)
|
||||
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := overview.FullDetailSummary()
|
||||
detail := ParseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := overview.GeneralSummary(false, false, false, false)
|
||||
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
//go:build android
|
||||
// +build android
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build !ios
|
||||
// +build !ios
|
||||
|
||||
package system
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
//go:build ios
|
||||
// +build ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
|
||||
@@ -510,7 +510,7 @@ func (s *serviceClient) saveSettings() {
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableUpdateSettings {
|
||||
log.Warn("Configuration updates are disabled by daemon")
|
||||
dialog.ShowError(fmt.Errorf("configuration updates are disabled by daemon"), s.wSettings)
|
||||
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -540,7 +540,7 @@ func (s *serviceClient) saveSettings() {
|
||||
func (s *serviceClient) validateSettings() error {
|
||||
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
|
||||
return fmt.Errorf("invalid pre-shared key value")
|
||||
return fmt.Errorf("Invalid Pre-shared Key Value")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -549,10 +549,10 @@ func (s *serviceClient) validateSettings() error {
|
||||
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
||||
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("invalid interface port")
|
||||
return 0, 0, errors.New("Invalid interface port")
|
||||
}
|
||||
if port < 1 || port > 65535 {
|
||||
return 0, 0, errors.New("invalid interface port: out of range 1-65535")
|
||||
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
|
||||
}
|
||||
|
||||
var mtu int64
|
||||
@@ -560,7 +560,7 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
||||
if mtuText != "" {
|
||||
mtu, err = strconv.ParseInt(mtuText, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("invalid MTU value")
|
||||
return 0, 0, errors.New("Invalid MTU value")
|
||||
}
|
||||
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
|
||||
return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
|
||||
@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
if sshJWTCacheTTLText != "" {
|
||||
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid SSH JWT Cache TTL value")
|
||||
return nil, errors.New("Invalid SSH JWT Cache TTL value")
|
||||
}
|
||||
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
|
||||
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)
|
||||
|
||||
@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = overview.FullDetailSummary()
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = overview.FullDetailSummary()
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
time.Now().Format(time.RFC3339), params.duration)
|
||||
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = overview.FullDetailSummary()
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
|
||||
request := &proto.DebugBundleRequest{
|
||||
|
||||
@@ -164,7 +164,7 @@ func sendShowWindowSignal(pid int32) error {
|
||||
|
||||
err = windows.SetEvent(eventHandle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting event: %w", err)
|
||||
return fmt.Errorf("Error setting event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -9,31 +9,20 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
defaultPeerConnectionTimeout = 60 * time.Second
|
||||
peerConnectionPollInterval = 500 * time.Millisecond
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
pingBufferSize = 1500
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -124,45 +113,18 @@ func createStopMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// validateSSHArgs validates SSH connection arguments
|
||||
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
|
||||
if len(args) < 2 {
|
||||
return "", 0, "", js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return "", 0, "", js.ValueOf("host parameter must be a string")
|
||||
}
|
||||
if args[1].Type() != js.TypeNumber {
|
||||
return "", 0, "", js.ValueOf("port parameter must be a number")
|
||||
}
|
||||
|
||||
host = args[0].String()
|
||||
port = args[1].Int()
|
||||
username = "root"
|
||||
|
||||
if len(args) > 2 {
|
||||
if args[2].Type() == js.TypeString && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
} else if args[2].Type() != js.TypeString {
|
||||
return "", 0, "", js.ValueOf("username parameter must be a string")
|
||||
}
|
||||
}
|
||||
|
||||
return host, port, username, js.Undefined()
|
||||
}
|
||||
|
||||
// createSSHMethod creates the SSH connection method
|
||||
func createSSHMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
host, port, username, validationErr := validateSSHArgs(args)
|
||||
if !validationErr.IsUndefined() {
|
||||
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
|
||||
return validationErr
|
||||
}
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(validationErr)
|
||||
})
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
username := "root"
|
||||
if len(args) > 2 && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
}
|
||||
|
||||
var jwtToken string
|
||||
@@ -171,9 +133,6 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when Dial is called in ssh.Connect().
|
||||
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
|
||||
@@ -195,110 +154,6 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
func performPing(client *netbird.Client, hostname string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "ping", hostname)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close ping connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
icmpData := make([]byte, 8)
|
||||
icmpData[0] = icmpEchoRequest
|
||||
icmpData[1] = icmpCodeEcho
|
||||
|
||||
if _, err := conn.Write(icmpData); err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, pingBufferSize)
|
||||
if _, err := conn.Read(buf); err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
|
||||
return
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
|
||||
}
|
||||
|
||||
func performPingTCP(client *netbird.Client, hostname string, port int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", hostname, port)
|
||||
start := time.Now()
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
|
||||
return
|
||||
}
|
||||
latency := time.Since(start)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close TCP connection: %v", err)
|
||||
}
|
||||
|
||||
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
|
||||
}
|
||||
|
||||
// createPingMethod creates the ping method
|
||||
func createPingMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: hostname required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPing(client, hostname)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createPingTCPMethod creates the pingtcp method
|
||||
func createPingTCPMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
if args[1].Type() != js.TypeNumber {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("port parameter must be a number"))
|
||||
})
|
||||
}
|
||||
|
||||
hostname := args[0].String()
|
||||
port := args[1].Int()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
performPingTCP(client, hostname, port)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createProxyRequestMethod creates the proxyRequest method
|
||||
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
@@ -307,11 +162,6 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
|
||||
request := args[0]
|
||||
if request.Type() != js.TypeObject {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("request parameter must be an object"))
|
||||
})
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
response, err := http.ProxyRequest(client, request)
|
||||
@@ -331,255 +181,23 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
|
||||
})
|
||||
}
|
||||
if args[1].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("port parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
// getStatusOverview is a helper to get the status overview
|
||||
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
|
||||
fullStatus, err := client.Status()
|
||||
if err != nil {
|
||||
return nbstatus.OutputOverview{}, err
|
||||
}
|
||||
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
statusResp := &proto.StatusResponse{
|
||||
DaemonVersion: version.NetbirdVersion(),
|
||||
FullStatus: pbFullStatus,
|
||||
}
|
||||
|
||||
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
|
||||
}
|
||||
|
||||
// createStatusMethod creates the status method that returns JSON
|
||||
func createStatusMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
jsonStr, err := overview.JSON()
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
|
||||
resolve.Invoke(jsonObj)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStatusSummaryMethod creates the statusSummary method
|
||||
func createStatusSummaryMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
summary := overview.GeneralSummary(false, false, false, false)
|
||||
js.Global().Get("console").Call("log", summary)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStatusDetailMethod creates the statusDetail method
|
||||
func createStatusDetailMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Info("statusDetail called")
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
log.Info("statusDetail: getting overview")
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
log.Errorf("statusDetail: getStatusOverview failed: %v", err)
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("statusDetail: generating full detail summary")
|
||||
detail := overview.FullDetailSummary()
|
||||
log.Infof("statusDetail: detail length=%d", len(detail))
|
||||
js.Global().Get("console").Call("log", detail)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createWaitForPeerConnectionMethod creates a method that waits for a peer to be connected
|
||||
func createWaitForPeerConnectionMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
if len(args) < 1 {
|
||||
reject.Invoke(js.ValueOf("peer IP address required"))
|
||||
return
|
||||
}
|
||||
|
||||
peerIP := args[0].String()
|
||||
timeoutMs := int(defaultPeerConnectionTimeout.Milliseconds())
|
||||
if len(args) > 1 && !args[1].IsUndefined() && !args[1].IsNull() {
|
||||
timeoutMs = args[1].Int()
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutMs) * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
log.Infof("Waiting for peer %s to be connected (timeout: %v)", peerIP, timeout)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("timeout waiting for peer %s to connect", peerIP)))
|
||||
return
|
||||
case <-ticker.C:
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting status: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range overview.Peers.Details {
|
||||
if peer.IP == peerIP && peer.Status == "Connected" {
|
||||
log.Infof("Peer %s is now connected (type: %s)", peerIP, peer.ConnType)
|
||||
resolve.Invoke(js.ValueOf(map[string]interface{}{
|
||||
"ip": peer.IP,
|
||||
"status": peer.Status,
|
||||
"connType": peer.ConnType,
|
||||
}))
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
|
||||
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
syncResp, err := client.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
options := protojson.MarshalOptions{
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
AllowPartial: true,
|
||||
}
|
||||
jsonBytes, err := options.Marshal(syncResp)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
|
||||
resolve.Invoke(jsonObj)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
|
||||
func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: log level required")
|
||||
}
|
||||
|
||||
if args[0].Type() != js.TypeString {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf("log level parameter must be a string"))
|
||||
})
|
||||
}
|
||||
|
||||
logLevel := args[0].String()
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
if err := client.SetLogLevel(logLevel); err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
|
||||
return
|
||||
}
|
||||
log.Infof("Log level set to: %s", logLevel)
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
// Wrap reject to always log the error
|
||||
loggingReject := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) > 0 {
|
||||
log.Errorf("Promise rejected: %v", args[0])
|
||||
js.Global().Get("console").Call("error", "WASM Promise rejected:", args[0])
|
||||
}
|
||||
reject.Invoke(args[0])
|
||||
return nil
|
||||
})
|
||||
|
||||
go handler(resolve, loggingReject.Value)
|
||||
go handler(resolve, reject)
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// waitForPeerConnected waits for a peer with the given IP to be connected
|
||||
func waitForPeerConnected(ctx context.Context, client *netbird.Client, peerIP string) error {
|
||||
log.Infof("Waiting for peer %s to be connected before operation", peerIP)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timeout waiting for peer %s to connect", peerIP)
|
||||
case <-ticker.C:
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting status while waiting for peer: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range overview.Peers.Details {
|
||||
if peer.IP == peerIP && peer.Status == "Connected" {
|
||||
log.Infof("Peer %s is now connected (type: %s), proceeding with operation", peerIP, peer.ConnType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createDetectSSHServerMethod creates the SSH server detection method
|
||||
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
@@ -602,10 +220,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when Dial is called. Waiting would cause
|
||||
// a deadlock since lazy connections only open on traffic.
|
||||
|
||||
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
@@ -623,25 +237,17 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
|
||||
obj["start"] = createStartMethod(client)
|
||||
obj["stop"] = createStopMethod(client)
|
||||
obj["ping"] = createPingMethod(client)
|
||||
obj["pingtcp"] = createPingTCPMethod(client)
|
||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
obj["waitForPeerConnection"] = createWaitForPeerConnectionMethod(client)
|
||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
@@ -657,10 +263,7 @@ func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return
|
||||
}
|
||||
|
||||
// Force trace logging for debugging - must be set BEFORE netbird.New
|
||||
// as New() will call logrus.SetLevel with options.LogLevel
|
||||
options.LogLevel = "trace"
|
||||
if err := util.InitLog("trace", util.LogConsole); err != nil {
|
||||
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
|
||||
log.Warnf("Failed to initialize logging: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,8 +47,8 @@ type CustomZone struct {
|
||||
Records []SimpleRecord
|
||||
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
|
||||
SearchDomainDisabled bool
|
||||
// NonAuthoritative marks user-created zones
|
||||
NonAuthoritative bool
|
||||
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
|
||||
SkipPTRProcess bool
|
||||
}
|
||||
|
||||
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
|
||||
|
||||
24
go.mod
24
go.mod
@@ -1,8 +1,6 @@
|
||||
module github.com/netbirdio/netbird
|
||||
|
||||
go 1.25
|
||||
|
||||
toolchain go1.25.5
|
||||
go 1.24.10
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.4.0
|
||||
@@ -42,7 +40,7 @@ require (
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||
github.com/dexidp/dex/api/v2 v2.4.0
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
@@ -78,12 +76,12 @@ require (
|
||||
github.com/pion/logging v0.2.4
|
||||
github.com/pion/randutil v0.1.0
|
||||
github.com/pion/stun/v2 v2.0.0
|
||||
github.com/pion/stun/v3 v3.1.0
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/pion/stun/v3 v3.0.0
|
||||
github.com/pion/transport/v3 v3.0.7
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/pkg/sftp v1.13.9
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/quic-go/quic-go v0.55.0
|
||||
github.com/quic-go/quic-go v0.49.1
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
@@ -105,7 +103,7 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.38.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
go.uber.org/mock v0.5.2
|
||||
go.uber.org/mock v0.5.0
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
@@ -122,7 +120,7 @@ require (
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/gorm v1.25.12
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -188,10 +186,12 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/go-text/render v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
|
||||
@@ -241,7 +241,7 @@ require (
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.7 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
github.com/pion/transport/v2 v2.2.4 // indirect
|
||||
github.com/pion/turn/v4 v4.1.1 // indirect
|
||||
@@ -263,7 +263,7 @@ require (
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
github.com/wlynxg/anet v0.0.3 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
@@ -285,7 +285,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
41
go.sum
41
go.sum
@@ -101,6 +101,9 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
|
||||
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
@@ -118,8 +121,8 @@ github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmr
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
|
||||
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -283,6 +286,7 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
@@ -407,8 +411,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
|
||||
@@ -444,8 +448,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
|
||||
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
|
||||
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
|
||||
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
|
||||
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
@@ -455,14 +459,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
|
||||
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
|
||||
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
|
||||
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
|
||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
||||
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
|
||||
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
|
||||
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
|
||||
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
||||
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
|
||||
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
|
||||
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
|
||||
@@ -487,8 +491,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
|
||||
github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
|
||||
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
|
||||
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
@@ -574,8 +578,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
|
||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
@@ -618,8 +622,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
|
||||
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
@@ -713,6 +717,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -843,5 +848,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
|
||||
gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
package dex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
// LogrusHandler is an slog.Handler that delegates to logrus.
|
||||
// This allows Dex to use the same log format as the rest of NetBird.
|
||||
type LogrusHandler struct {
|
||||
logger *logrus.Logger
|
||||
attrs []slog.Attr
|
||||
groups []string
|
||||
}
|
||||
|
||||
// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
|
||||
func NewLogrusHandler(level slog.Level) *LogrusHandler {
|
||||
logger := logrus.New()
|
||||
formatter.SetTextFormatter(logger)
|
||||
|
||||
// Map slog level to logrus level
|
||||
switch level {
|
||||
case slog.LevelDebug:
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
case slog.LevelInfo:
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
case slog.LevelWarn:
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
case slog.LevelError:
|
||||
logger.SetLevel(logrus.ErrorLevel)
|
||||
default:
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
}
|
||||
|
||||
return &LogrusHandler{logger: logger}
|
||||
}
|
||||
|
||||
// Enabled reports whether the handler handles records at the given level.
|
||||
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
switch level {
|
||||
case slog.LevelDebug:
|
||||
return h.logger.IsLevelEnabled(logrus.DebugLevel)
|
||||
case slog.LevelInfo:
|
||||
return h.logger.IsLevelEnabled(logrus.InfoLevel)
|
||||
case slog.LevelWarn:
|
||||
return h.logger.IsLevelEnabled(logrus.WarnLevel)
|
||||
case slog.LevelError:
|
||||
return h.logger.IsLevelEnabled(logrus.ErrorLevel)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Handle handles the Record.
|
||||
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
fields := make(logrus.Fields)
|
||||
|
||||
// Add pre-set attributes
|
||||
for _, attr := range h.attrs {
|
||||
fields[attr.Key] = attr.Value.Any()
|
||||
}
|
||||
|
||||
// Add record attributes
|
||||
r.Attrs(func(attr slog.Attr) bool {
|
||||
fields[attr.Key] = attr.Value.Any()
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.WithFields(fields)
|
||||
|
||||
switch r.Level {
|
||||
case slog.LevelDebug:
|
||||
entry.Debug(r.Message)
|
||||
case slog.LevelInfo:
|
||||
entry.Info(r.Message)
|
||||
case slog.LevelWarn:
|
||||
entry.Warn(r.Message)
|
||||
case slog.LevelError:
|
||||
entry.Error(r.Message)
|
||||
default:
|
||||
entry.Info(r.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithAttrs returns a new Handler with the given attributes added.
|
||||
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
|
||||
copy(newAttrs, h.attrs)
|
||||
copy(newAttrs[len(h.attrs):], attrs)
|
||||
return &LogrusHandler{
|
||||
logger: h.logger,
|
||||
attrs: newAttrs,
|
||||
groups: h.groups,
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroup returns a new Handler with the given group appended to the receiver's groups.
|
||||
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
|
||||
newGroups := make([]string, len(h.groups)+1)
|
||||
copy(newGroups, h.groups)
|
||||
newGroups[len(h.groups)] = name
|
||||
return &LogrusHandler{
|
||||
logger: h.logger,
|
||||
attrs: h.attrs,
|
||||
groups: newGroups,
|
||||
}
|
||||
}
|
||||
@@ -130,21 +130,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
|
||||
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
|
||||
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
|
||||
logLevel := slog.LevelWarn
|
||||
if yamlConfig.Logger.Level != "" {
|
||||
switch strings.ToLower(yamlConfig.Logger.Level) {
|
||||
case "debug":
|
||||
logLevel = slog.LevelDebug
|
||||
case "info":
|
||||
logLevel = slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
logLevel = slog.LevelWarn
|
||||
case "error":
|
||||
logLevel = slog.LevelError
|
||||
}
|
||||
}
|
||||
logger := slog.New(NewLogrusHandler(logLevel))
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
stor, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
@@ -792,12 +778,11 @@ func (p *Provider) resolveRedirectURI(redirectURI string) string {
|
||||
// buildOIDCConnectorConfig creates config for OIDC-based connectors
|
||||
func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
|
||||
oidcConfig := map[string]interface{}{
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
"insecureEnableGroups": true,
|
||||
"issuer": cfg.Issuer,
|
||||
"clientID": cfg.ClientID,
|
||||
"clientSecret": cfg.ClientSecret,
|
||||
"redirectURI": redirectURI,
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
}
|
||||
switch cfg.Type {
|
||||
case "zitadel":
|
||||
@@ -807,9 +792,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
||||
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
|
||||
case "okta":
|
||||
oidcConfig["insecureSkipEmailVerified"] = true
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
case "pocketid":
|
||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||
}
|
||||
return encodeConnectorConfig(oidcConfig)
|
||||
}
|
||||
|
||||
@@ -270,7 +270,7 @@ AUTH_CLIENT_ID=netbird-dashboard
|
||||
AUTH_CLIENT_SECRET=
|
||||
AUTH_AUTHORITY=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
|
||||
USE_AUTH0=false
|
||||
AUTH_SUPPORTED_SCOPES=openid profile email groups
|
||||
AUTH_SUPPORTED_SCOPES=openid profile email offline_access
|
||||
AUTH_REDIRECT_URI=/nb-auth
|
||||
AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
# SSL
|
||||
|
||||
@@ -64,7 +64,7 @@ var (
|
||||
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
|
||||
}
|
||||
|
||||
var tlsEnabled bool
|
||||
tlsEnabled := false
|
||||
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
|
||||
tlsEnabled = true
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
|
||||
applyCommandLineOverrides(loadedConfig)
|
||||
|
||||
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
||||
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
|
||||
err := applyEmbeddedIdPConfig(loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
|
||||
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
||||
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
@@ -190,8 +190,10 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
|
||||
userDeleteFromIDPEnabled = true
|
||||
|
||||
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
|
||||
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
|
||||
// Ensure HttpConfig exists
|
||||
if cfg.HttpConfig == nil {
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
}
|
||||
|
||||
// Set storage defaults based on Datadir
|
||||
if cfg.EmbeddedIdP.Storage.Type == "" {
|
||||
@@ -203,22 +205,40 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
if cfg.HttpConfig != nil {
|
||||
log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
|
||||
"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
|
||||
} else {
|
||||
// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
// Set AuthIssuer from EmbeddedIdP issuer
|
||||
if cfg.HttpConfig.AuthIssuer == "" {
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
}
|
||||
|
||||
// Set HttpConfig values from EmbeddedIdP
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
// Set AuthAudience to the dashboard client ID
|
||||
if cfg.HttpConfig.AuthAudience == "" {
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
}
|
||||
|
||||
// Set CLIAuthAudience to the client app client ID
|
||||
if cfg.HttpConfig.CLIAuthAudience == "" {
|
||||
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
|
||||
}
|
||||
|
||||
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
|
||||
if cfg.HttpConfig.AuthUserIDClaim == "" {
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
}
|
||||
|
||||
// Set AuthKeysLocation to the JWKS endpoint
|
||||
if cfg.HttpConfig.AuthKeysLocation == "" {
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
}
|
||||
|
||||
// Set OIDCConfigEndpoint to the discovery endpoint
|
||||
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
}
|
||||
|
||||
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
|
||||
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -226,12 +246,7 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
|
||||
// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
|
||||
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
compactNetworkMapMinVersion = "v0.61.0" // TODO change to real version
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
repo Repository
|
||||
metrics *metrics
|
||||
@@ -483,6 +487,11 @@ func (c *Controller) getPeerNetworkMapExp(
|
||||
}
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerId)
|
||||
if peer != nil && supportsCompactNetworkMap(peer) {
|
||||
return account.GetPeerNetworkMapCompactExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
@@ -622,6 +631,19 @@ func (c *Controller) StartWarmup(ctx context.Context) {
|
||||
|
||||
}
|
||||
|
||||
func supportsCompactNetworkMap(peer *nbpeer.Peer) bool {
|
||||
if peer.Meta.WtVersion == "development" || peer.Meta.WtVersion == "dev" {
|
||||
return true
|
||||
}
|
||||
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return semver.Compare(peerVersion, compactNetworkMapMinVersion) >= 0
|
||||
}
|
||||
|
||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
|
||||
@@ -68,8 +68,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
if len(audiences) > 0 {
|
||||
audience = audiences[0] // Use the first client ID as the primary audience
|
||||
}
|
||||
// Use localhost keys location for internal validation (management has embedded Dex)
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
keysLocation = oauthProvider.GetKeysLocation()
|
||||
signingKeyRefreshEnabled = true
|
||||
issuer = oauthProvider.GetIssuer()
|
||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||
|
||||
@@ -374,9 +374,8 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
|
||||
@@ -85,7 +85,6 @@ func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(2, s.filter.logged[pubKey].banLevel)
|
||||
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
// nolint
|
||||
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
|
||||
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
|
||||
}
|
||||
|
||||
@@ -1006,7 +1006,7 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
|
||||
for user, loggedInOnce := range accountUsers {
|
||||
if datum, ok := userDataMap[user]; ok {
|
||||
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint
|
||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
||||
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -753,7 +753,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account.Domain != domain {
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
|
||||
}
|
||||
|
||||
@@ -768,7 +768,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
t.Fatalf("expected to get an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account.Domain != domain {
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
|
||||
}
|
||||
}
|
||||
@@ -3465,11 +3465,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain})
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
25
management/server/cache/idp.go
vendored
25
management/server/cache/idp.go
vendored
@@ -26,8 +26,6 @@ type UserDataCache interface {
|
||||
Get(ctx context.Context, key string) (*idp.UserData, error)
|
||||
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
|
||||
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
|
||||
}
|
||||
|
||||
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
|
||||
@@ -53,29 +51,6 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
|
||||
return u.cache.Delete(ctx, key)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
|
||||
var users []*idp.UserData
|
||||
v, err := u.cache.Get(ctx, key, &users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case []*idp.UserData:
|
||||
return v, nil
|
||||
case *[]*idp.UserData:
|
||||
return *v, nil
|
||||
case []byte:
|
||||
return unmarshalUserData(v)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected type: %T", v)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
|
||||
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
|
||||
}
|
||||
|
||||
// NewUserDataCache creates a new UserDataCacheImpl object.
|
||||
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
|
||||
simpleCache := cache.New[any](store)
|
||||
|
||||
@@ -893,7 +893,6 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
peer := &peer2.Peer{
|
||||
ID: strconv.Itoa(i),
|
||||
AccountID: accountID,
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
DNSLabel: "peer" + strconv.Itoa(i),
|
||||
IP: uint32ToIP(uint32(i)),
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%v", err) //nolint
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
|
||||
@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if m.rateLimiter != nil {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
@@ -214,11 +214,6 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
ua := strings.ToLower(r.Header.Get("User-Agent"))
|
||||
return strings.Contains(ua, "terraform")
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
|
||||
@@ -508,103 +508,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
|
||||
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test various Terraform user agent formats
|
||||
terraformUserAgents := []string{
|
||||
"Terraform/1.5.0",
|
||||
"terraform/1.0.0",
|
||||
"Terraform-Provider/2.0.0",
|
||||
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
|
||||
}
|
||||
|
||||
for _, userAgent := range terraformUserAgents {
|
||||
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build benchmark
|
||||
// +build benchmark
|
||||
|
||||
package benchmarks
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build benchmark
|
||||
// +build benchmark
|
||||
|
||||
package benchmarks
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build benchmark
|
||||
// +build benchmark
|
||||
|
||||
package benchmarks
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package integration
|
||||
|
||||
|
||||
@@ -2,13 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -23,69 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// oidcProviderJSON represents the OpenID Connect discovery document
|
||||
type oidcProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
}
|
||||
|
||||
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
|
||||
// and verifying that the returned issuer matches the configured one.
|
||||
func validateOIDCIssuer(ctx context.Context, issuer string) error {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
|
||||
}
|
||||
|
||||
var p oidcProviderJSON
|
||||
if err := json.Unmarshal(body, &p); err != nil {
|
||||
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if p.Issuer != issuer {
|
||||
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateIdentityProviderConfig validates the identity provider configuration including
|
||||
// basic validation and OIDC issuer verification.
|
||||
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
// Validate the issuer by calling the OIDC discovery endpoint
|
||||
if idpConfig.Issuer != "" {
|
||||
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIdentityProviders returns all identity providers for an account
|
||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
@@ -151,8 +82,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
@@ -188,8 +119,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user