Compare commits

..

1 Commit

Author SHA1 Message Date
Viktor Liu
063cbdc6d8 Enable trace logging in WASM client 2026-01-16 15:06:58 +08:00
425 changed files with 4475 additions and 55700 deletions

View File

@@ -1,6 +0,0 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -46,5 +46,6 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -144,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -204,7 +204,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
@@ -261,53 +261,6 @@ jobs:
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
# Proxy unit-test job: runs the ./proxy/... Go tests once per matrix arch (386 and amd64).
# NOTE(review): `needs: [build-cache]` presumably refers to the job that populates the Go cache restored below; confirm.
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
# setup-go's built-in cache is disabled; the cache is restored explicitly in a later step.
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
# NOTE(review): the multilib packages are presumably required for the CGO-enabled 386 build; confirm.
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
# Exports the GOCACHE and GOMODCACHE paths as env.cache and env.modcache for the cache step.
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
# `go mod tidy` followed by `git diff --exit-code` fails the job when tidy changes tracked files.
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,8 +19,8 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum,**/proxy/web/**
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
skip: go.mod,go.sum
golangci:
strategy:
fail-fast: false

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.1"
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
.run
*.iml
dist/
!proxy/web/dist/
bin/
.env
conf.json

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2
FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -3,7 +3,15 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -108,17 +129,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err, _ = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -137,41 +160,49 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
go urlOpener.OnLoginSuccess()
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
if err != nil {
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
return nil, err
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -181,10 +212,22 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
// withBackOff runs bf through backoff.RetryNotify, using cmd.CLIBackOffSettings
// bound to ctx as the retry policy. A warning is logged before each retry with
// the wait duration and the error that triggered it.
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
	policy := backoff.WithContext(cmd.CLIBackOffSettings, ctx)

	onRetry := func(err error, duration time.Duration) {
		log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
	}

	return backoff.RetryNotify(bf, policy, onRetry)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
@@ -97,6 +98,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -219,37 +221,21 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
cpuProfilingStarted := false
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = true
defer func() {
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
}
}
}()
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = false
}
}
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -316,6 +302,24 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil
}
// getStatusOutput returns the full-detail status summary as a string.
//
// It calls getStatus with cmd's context; on failure the error is printed to the
// command's error output and an empty string is returned. The active profile
// name is included when the profile manager can resolve it, otherwise it is
// left empty. anon is passed through to the status overview conversion.
func getStatusOutput(cmd *cobra.Command, anon bool) string {
	statusResp, err := getStatus(cmd.Context(), true)
	if err != nil {
		cmd.PrintErrf("Failed to get status: %v\n", err)
		return ""
	}

	// A failed profile lookup is not fatal: the summary is built without a profile name.
	profName := ""
	if activeProf, profErr := profilemanager.NewProfileManager().GetActiveProfile(); profErr == nil {
		profName = activeProf.Name
	}

	return nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName).FullDetailSummary()
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -374,8 +378,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config,
StatusRecorder: recorder,
SyncResponse: syncResponse,
LogPath: logFilePath,
CPUProfile: nil,
LogFile: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -7,6 +7,7 @@ import (
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -276,15 +277,18 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
needsLogin, err := authClient.IsLoginRequired(ctx)
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("check login required: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
@@ -296,9 +300,23 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
err, _ = authClient.Login(ctx, setupKey, jwtToken)
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
@@ -326,7 +344,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
@@ -98,8 +97,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -118,7 +115,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -127,7 +124,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {

View File

@@ -10,13 +10,13 @@ import (
"net/netip"
"os"
"sync"
"time"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
@@ -31,12 +31,9 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
)
// Client manages a netbird embedded client instance.
@@ -77,10 +74,6 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
}
// validateCredentials checks that exactly one credential type is provided
@@ -149,8 +142,6 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -170,7 +161,6 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil
}
@@ -192,17 +182,13 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -210,7 +196,7 @@ func (c *Client) Start(startCtx context.Context) error {
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run, ""); err != nil {
if err := client.Run(run); err != nil {
clientErr <- err
}
}()
@@ -278,18 +264,40 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
// With lazy connections, the connection is established on first traffic.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
// Check context status upfront
if ctx.Err() != nil {
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
return nil, ctx.Err()
}
engine, err := c.getEngine()
if err != nil {
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
return nil, err
}
nsnet, err := engine.GetNet()
if err != nil {
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when DialContext is called. The netstack
// dial triggers WireGuard traffic which activates the lazy connection.
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
conn, err := nsnet.DialContext(ctx, network, address)
if err != nil {
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
return nil, err
}
logrus.Infof("embed.Dial: successfully connected to %s", address)
return conn, nil
}
// DialContext dials a network address in the netbird network with context
@@ -355,9 +363,14 @@ func (c *Client) NewHTTPClient() *http.Client {
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
@@ -365,7 +378,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
}
}
return c.recorder.GetFullStatus(), nil
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
@@ -383,6 +396,35 @@ func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
return syncResp, nil
}
// WaitForPeerConnection blocks until the peer with the given IP is reported as
// connected, or until ctx is done.
//
// The peer state is checked once immediately, so an already-connected peer does
// not cost a full poll interval, and then every peerConnectionPollInterval.
// Errors from Status are logged at debug level and treated as "not connected yet".
//
// It returns nil once the peer is connected, or an error wrapping ctx.Err()
// when the context is cancelled or its deadline expires first.
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
	logrus.Infof("Waiting for peer %s to be connected", peerIP)

	// connected reports whether the current status lists peerIP as connected.
	connected := func() bool {
		status, err := c.Status()
		if err != nil {
			logrus.Debugf("Error getting status while waiting for peer: %v", err)
			return false
		}

		for _, p := range status.Peers {
			if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
				logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
				return true
			}
		}

		logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
		return false
	}

	// Fast path: a ticker only fires after the first interval has elapsed, so
	// check right away. A context that is already done is skipped here and
	// reported by the loop below, exactly as before.
	if ctx.Err() == nil && connected() {
		return nil
	}

	ticker := time.NewTicker(peerConnectionPollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
		case <-ticker.C:
			if connected() {
				return nil
			}
		}
	}
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
@@ -396,9 +438,8 @@ func (c *Client) SetLogLevel(levelStr string) error {
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
// Note: ConnectClient doesn't have SetLogLevel method
_ = connect
return nil
}

View File

@@ -83,10 +83,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -181,10 +177,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.cleanupNoTrackChain(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -285,125 +277,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
// chainNameRaw is the name of the NetBird chain that holds the NOTRACK rules.
chainNameRaw = "NETBIRD-RAW"
// chainOUTPUT is the OUTPUT chain name, one of the chains a jump to chainNameRaw is inserted into.
chainOUTPUT = "OUTPUT"
// tableRaw is the name of the table in which chainNameRaw is created.
tableRaw = "raw"
)
// SetupEBPFProxyNoTrack installs NOTRACK rules for the eBPF proxy's loopback UDP traffic.
// Without them conntrack tracks the WireGuard proxy traffic on loopback, which can
// interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Four flows are exempted from tracking:
//
//   - Egress, WireGuard -> fake endpoint (before eBPF rewrite):
//     src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort, matched by sport=wgPort
//   - Egress, proxy -> WireGuard (via raw socket):
//     src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort, matched by dport=wgPort
//   - Ingress, packets to WireGuard:
//     dst=127.0.0.1:wgPort, matched by dport=wgPort
//   - Ingress, packets to the proxy (after eBPF rewrite):
//     dst=127.0.0.1:proxyPort, matched by dport=proxyPort
//
// All rules are appended to the NetBird raw chain and are cleaned up when the
// firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	wg := fmt.Sprintf("%d", wgPort)
	proxy := fmt.Sprintf("%d", proxyPort)

	// notrackSpec builds a NOTRACK rule spec for loopback-to-loopback UDP traffic.
	// ifaceFlag selects the direction ("-o" egress, "-i" ingress); portFlag and
	// port select which UDP port is matched.
	notrackSpec := func(ifaceFlag, portFlag, port string) []string {
		return []string{ifaceFlag, "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", portFlag, port, "-j", "NOTRACK"}
	}

	// appendRule adds the spec to the NetBird raw chain unless it is already present.
	appendRule := func(spec []string) error {
		return m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, spec...)
	}

	// Egress: outgoing loopback UDP packets from and to the WireGuard port.
	if err := appendRule(notrackSpec("-o", "--sport", wg)); err != nil {
		return fmt.Errorf("add output sport notrack rule: %w", err)
	}
	if err := appendRule(notrackSpec("-o", "--dport", wg)); err != nil {
		return fmt.Errorf("add output dport notrack rule: %w", err)
	}

	// Ingress: incoming loopback UDP packets to the WireGuard port and to the proxy port.
	if err := appendRule(notrackSpec("-i", "--dport", wg)); err != nil {
		return fmt.Errorf("add prerouting wg notrack rule: %w", err)
	}
	if err := appendRule(notrackSpec("-i", "--dport", proxy)); err != nil {
		return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
	}

	log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
	return nil
}
// initNoTrackChain creates the notrack chain in the raw table and inserts a
// jump to it at position 1 of both OUTPUT and PREROUTING. If a jump cannot be
// inserted, whatever was created so far is removed again before returning.
func (m *Manager) initNoTrackChain() error {
	// Remove any existing chain first; a failure here is only logged.
	if cleanupErr := m.cleanupNoTrackChain(); cleanupErr != nil {
		log.Debugf("cleanup notrack chain: %v", cleanupErr)
	}

	if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
		return fmt.Errorf("create chain: %w", err)
	}

	jump := []string{"-j", chainNameRaw}

	// dropChain deletes the freshly created chain; errors are only logged.
	dropChain := func() {
		if err := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); err != nil {
			log.Debugf("delete orphan chain: %v", err)
		}
	}

	outputErr := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jump...)
	if outputErr != nil {
		dropChain()
		return fmt.Errorf("add output jump rule: %w", outputErr)
	}

	preroutingErr := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jump...)
	if preroutingErr == nil {
		return nil
	}

	// The OUTPUT jump must go before the chain it references can be deleted.
	if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jump...); err != nil {
		log.Debugf("delete output jump rule: %v", err)
	}
	dropChain()
	return fmt.Errorf("add prerouting jump rule: %w", preroutingErr)
}
// cleanupNoTrackChain removes the jump rules pointing at the notrack chain
// from OUTPUT and PREROUTING and then clears and deletes the chain itself.
// It returns nil without doing anything when the chain does not exist.
func (m *Manager) cleanupNoTrackChain() error {
	present, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
	switch {
	case err != nil:
		return fmt.Errorf("check chain exists: %w", err)
	case !present:
		return nil
	}

	jump := []string{"-j", chainNameRaw}

	err = m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jump...)
	if err != nil {
		return fmt.Errorf("remove output jump rule: %w", err)
	}

	err = m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jump...)
	if err != nil {
		return fmt.Errorf("remove prerouting jump rule: %w", err)
	}

	err = m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw)
	if err != nil {
		return fmt.Errorf("clear and delete chain: %w", err)
	}

	return nil
}
// getConntrackEstablished returns the rule arguments that accept packets
// whose conntrack state is RELATED or ESTABLISHED.
func getConntrackEstablished() []string {
	spec := make([]string, 0, 6)
	spec = append(spec, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED")
	spec = append(spec, "-j", "ACCEPT")
	return spec
}

View File

@@ -168,10 +168,6 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,10 +48,8 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
router *router
aclManager *AclManager
}
// Create nftables firewall manager
@@ -94,10 +91,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
@@ -295,15 +288,7 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.Flush(); err != nil {
return err
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
return nil
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
@@ -346,176 +331,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// Names of the nftables chains that hold the eBPF proxy notrack rules.
const (
// chainNameRawOutput names the chain registered on the output hook at raw priority.
chainNameRawOutput = "netbird-raw-out"
// chainNameRawPrerouting names the chain registered on the prerouting hook at raw priority.
chainNameRawPrerouting = "netbird-raw-pre"
)
// SetupEBPFProxyNoTrack installs notrack rules for the eBPF proxy's loopback traffic.
// Keeping WireGuard proxy traffic on loopback out of conntrack prevents it from
// interfering with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Flows exempted from tracking:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// It returns an error if the notrack chains have not been initialized or if
// flushing the queued rules fails. The rules are removed when the firewall
// manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
		return fmt.Errorf("notrack chains not initialized")
	}

	proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
	wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
	loopback := []byte{127, 0, 0, 1}

	// Offsets of the 2-byte port fields inside the transport header.
	const (
		srcPortOffset = 0
		dstPortOffset = 2
	)

	// queueRule queues one rule on the given chain: interface "lo" (matched via
	// ifaceKey), saddr and daddr both 127.0.0.1, UDP, and the given port at
	// portOffset; matching packets are counted and marked notrack.
	queueRule := func(chain *nftables.Chain, ifaceKey expr.MetaKey, portOffset uint32, port []byte) {
		m.rConn.AddRule(&nftables.Rule{
			Table: chain.Table,
			Chain: chain,
			Exprs: []expr.Any{
				&expr.Meta{Key: ifaceKey, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: portOffset, Len: 2},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: port},
				&expr.Counter{},
				&expr.Notrack{},
			},
		})
	}

	// Egress rules: match outgoing loopback UDP packets
	queueRule(m.notrackOutputChain, expr.MetaKeyOIFNAME, srcPortOffset, wgPortBytes) // sport=wgPort
	queueRule(m.notrackOutputChain, expr.MetaKeyOIFNAME, dstPortOffset, wgPortBytes) // dport=wgPort

	// Ingress rules: match incoming loopback UDP packets
	queueRule(m.notrackPreroutingChain, expr.MetaKeyIIFNAME, dstPortOffset, wgPortBytes)    // dport=wgPort
	queueRule(m.notrackPreroutingChain, expr.MetaKeyIIFNAME, dstPortOffset, proxyPortBytes) // dport=proxyPort

	if err := m.rConn.Flush(); err != nil {
		return fmt.Errorf("flush notrack rules: %w", err)
	}

	log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
	return nil
}
// initNoTrackChains queues the output and prerouting notrack chains in the
// given table (filter type, raw priority), stores the returned chain handles
// on the manager, and flushes the connection.
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
	outputSpec := &nftables.Chain{
		Name:     chainNameRawOutput,
		Table:    table,
		Type:     nftables.ChainTypeFilter,
		Hooknum:  nftables.ChainHookOutput,
		Priority: nftables.ChainPriorityRaw,
	}
	preroutingSpec := &nftables.Chain{
		Name:     chainNameRawPrerouting,
		Table:    table,
		Type:     nftables.ChainTypeFilter,
		Hooknum:  nftables.ChainHookPrerouting,
		Priority: nftables.ChainPriorityRaw,
	}

	m.notrackOutputChain = m.rConn.AddChain(outputSpec)
	m.notrackPreroutingChain = m.rConn.AddChain(preroutingSpec)

	flushErr := m.rConn.Flush()
	if flushErr != nil {
		return fmt.Errorf("flush chain creation: %w", flushErr)
	}
	return nil
}
// refreshNoTrackChains lists the IPv4 chains and re-points the manager's
// notrack chain references at the listed entries that belong to the table
// named by getTableName. Chains that are not found leave the existing
// references untouched.
func (m *Manager) refreshNoTrackChains() error {
	allChains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
	if err != nil {
		return fmt.Errorf("list chains: %w", err)
	}

	ownTable := getTableName()
	for _, chain := range allChains {
		if chain.Table.Name == ownTable {
			if chain.Name == chainNameRawOutput {
				m.notrackOutputChain = chain
			} else if chain.Name == chainNameRawPrerouting {
				m.notrackPreroutingChain = chain
			}
		}
	}

	return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -483,12 +483,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -665,32 +660,13 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// rollbackRules undoes the bookkeeping for a pair's rules after a failed flush:
// for the forwarding, prerouting and inverse-prerouting keys it decrements the
// rule's set counter (failures are logged as warnings) and drops the entry
// from the rules map. Keys that are not in the map are ignored.
func (r *router) rollbackRules(pair firewall.RouterPair) {
	for _, ruleKey := range []string{
		firewall.GenKey(firewall.ForwardingFormat, pair),
		firewall.GenKey(firewall.PreroutingFormat, pair),
		firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
	} {
		if staged, found := r.rules[ruleKey]; found {
			if err := r.decrementSetCounter(staged); err != nil {
				log.Warnf("rollback set counter for %s: %v", ruleKey, err)
			}
			delete(r.rules, ruleKey)
		}
	}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -952,30 +928,18 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
}
return nil
@@ -1365,89 +1329,65 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
return fmt.Errorf("remove legacy routing rule: %w", err)
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
}
return nil
}
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
// removeNatRule queues deletion of the prerouting rule for the given pair and
// updates the local bookkeeping. It does not flush the connection itself.
//
// If no rule is stored under the pair's prerouting key, this is logged at debug
// level and nil is returned. Otherwise the rule is passed to DelRule, removed
// from the rules map, and decrementSetCounter is called for it. Errors from
// DelRule and decrementSetCounter are wrapped with %w so callers can inspect
// the underlying cause with errors.Is / errors.As.
func (r *router) removeNatRule(pair firewall.RouterPair) error {
	ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)

	rule, exists := r.rules[ruleKey]
	if !exists {
		log.Debugf("prerouting rule %s not found", ruleKey)
		return nil
	}

	if err := r.conn.DelRule(rule); err != nil {
		return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
	}

	log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)

	// Drop the map entry before touching the set counter, as the original did.
	delete(r.rules, ruleKey)

	if err := r.decrementSetCounter(rule); err != nil {
		return fmt.Errorf("decrement set counter: %w", err)
	}

	return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
return fmt.Errorf("list rules: %w", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[string(rule.UserData)] = rule
r.rules[string(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
return nil
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1689,34 +1629,20 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
@@ -1831,25 +1757,16 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -18,7 +18,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -720,137 +719,3 @@ func deleteWorkTable() {
}
}
}
// TestRouter_RefreshRulesMap_RemovesStaleEntries checks that refreshRulesMap
// drops a map entry that has no handle while keeping a rule that was really
// added, and that the kept rule ends up with a non-zero handle.
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	table, err := createWorkTable()
	require.NoError(t, err)
	defer deleteWorkTable()

	rtr, err := newRouter(table, ifaceMock, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, rtr.init(table))
	defer func() { require.NoError(t, rtr.Reset()) }()

	// Install a genuine route-filtering rule.
	sources := []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}
	destination := firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")}
	dstPort := &firewall.Port{Values: []uint16{80}}
	liveRule, err := rtr.AddRouteFiltering(
		nil,
		sources,
		destination,
		firewall.ProtocolTCP,
		nil,
		dstPort,
		firewall.ActionAccept,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, rtr.DeleteRouteRule(liveRule))
	})

	// Plant a map entry with Handle=0 (simulates store-before-flush failure).
	orphanKey := "stale-rule-that-does-not-exist"
	orphan := &nftables.Rule{
		Table:    rtr.workTable,
		Chain:    rtr.chains[chainNameRoutingFw],
		Handle:   0,
		UserData: []byte(orphanKey),
	}
	rtr.rules[orphanKey] = orphan
	require.Contains(t, rtr.rules, orphanKey, "stale entry should be in map before refresh")

	require.NoError(t, rtr.refreshRulesMap())

	assert.NotContains(t, rtr.rules, orphanKey, "stale entry should be removed after refresh")

	refreshed, present := rtr.rules[liveRule.ID()]
	assert.True(t, present, "real rule should still exist after refresh")
	assert.NotZero(t, refreshed.Handle, "real rule should have a valid handle")
}
// TestRouter_DeleteRouteRule_StaleHandle checks that DeleteRouteRule succeeds
// for a map entry whose handle is zero and removes that entry from the map.
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	table, err := createWorkTable()
	require.NoError(t, err)
	defer deleteWorkTable()

	rtr, err := newRouter(table, ifaceMock, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, rtr.init(table))
	defer func() { require.NoError(t, rtr.Reset()) }()

	// Plant a map entry with Handle=0.
	orphanKey := "stale-route-rule"
	orphan := &nftables.Rule{
		Table:    rtr.workTable,
		Chain:    rtr.chains[chainNameRoutingFw],
		Handle:   0,
		UserData: []byte(orphanKey),
	}
	rtr.rules[orphanKey] = orphan

	// Deleting it must not be reported as an error.
	deleteErr := rtr.DeleteRouteRule(id.RuleID(orphanKey))
	assert.NoError(t, deleteErr, "deleting a stale rule should not error")
	assert.NotContains(t, rtr.rules, orphanKey, "stale entry should be cleaned up")
}
// TestRouter_AddNatRule_WithStaleEntry checks that AddNatRule can be repeated
// for a pair whose stored prerouting rules had their handles zeroed, and that
// the kernel then holds exactly one rule tagged with the pair's prerouting key.
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
	if check() != NFTABLES {
		t.Skip("nftables not supported on this system")
	}

	manager, err := Create(ifaceMock, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, manager.Init(nil))
	t.Cleanup(func() {
		require.NoError(t, manager.Close(nil))
	})

	pair := firewall.RouterPair{
		ID:          "staletest",
		Source:      firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
		Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
		Masquerade:  true,
	}

	rtr := manager.router

	// Initial add must succeed.
	require.NoError(t, rtr.AddNatRule(pair))
	t.Cleanup(func() {
		require.NoError(t, rtr.RemoveNatRule(pair))
	})

	// Zero the handles of both prerouting entries to simulate stale state.
	natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
	inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
	for _, key := range []string{natRuleKey, inverseKey} {
		if staged, ok := rtr.rules[key]; ok {
			staged.Handle = 0
		}
	}

	// A repeated add should still go through.
	err = rtr.AddNatRule(pair)
	assert.NoError(t, err, "AddNatRule should succeed even with stale entries")

	// Count the kernel rules tagged with the pair's prerouting key.
	kernelRules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
	require.NoError(t, err)

	matches := 0
	for _, kernelRule := range kernelRules {
		if len(kernelRule.UserData) == 0 {
			continue
		}
		if string(kernelRule.UserData) == natRuleKey {
			matches++
		}
	}
	assert.Equal(t, 1, matches, "NAT rule should exist in kernel")
}

View File

@@ -3,6 +3,12 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -11,7 +17,33 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,9 +1,12 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -23,7 +26,33 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() {
return nil

View File

@@ -115,17 +115,6 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// IsSupersededBy reports whether a packet carrying the given flags should
// replace this tracked connection with a new one. A tombstoned connection is
// always superseded. A connection in TIME-WAIT is superseded only by a pure
// SYN (SYN set, ACK clear), i.e. a new connection attempt on the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191.
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
	if t.tombstone.Load() {
		return true
	}

	pureSyn := flags&TCPSyn != 0 && flags&TCPAck == 0
	if !pureSyn {
		return false
	}

	return TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
@@ -180,7 +169,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists && !conn.IsSupersededBy(flags) {
if exists {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
}
@@ -252,7 +241,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsSupersededBy(flags) {
if !exists || conn.IsTombstone() {
return false
}

View File

@@ -485,261 +485,6 @@ func TestTCPAbnormalSequences(t *testing.T) {
})
}
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
//
// Each subtest creates its own tracker, drives it with a fixed packet sequence
// and inspects tracker.connections directly under the same four-tuple key.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Outbound connection closed by a FIN/FIN-ACK/ACK exchange, then a new outbound SYN on the same tuple.
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
// Outbound connection torn down by an inbound RST, then a new outbound SYN on the same tuple.
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
// Inbound connection closed from the server side, then a new inbound SYN on the same tuple.
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
// Non-SYN packets arriving after the close must neither be accepted nor replace the tombstoned entry.
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond

View File

@@ -1,7 +1,6 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -13,13 +12,11 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -27,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -93,7 +89,6 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -234,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -486,15 +480,11 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: string(ruleKey),
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -509,7 +499,6 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -526,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
return r.id == ruleID
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
return fmt.Errorf("route rule not found: %s", ruleID)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -586,48 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush is a no-op for this manager and always returns nil.
func (m *Manager) Flush() error {
	return nil
}
// resetState empties every rule collection and shuts down the connection
// trackers, the forwarder and the logger.
// The caller must already hold m.mutex.
func (m *Manager) resetState() {
	// Rule maps are cleared in place; the route rule slice is truncated to length zero.
	maps.Clear(m.outgoingRules)
	maps.Clear(m.incomingDenyRules)
	maps.Clear(m.incomingRules)
	maps.Clear(m.routeRulesMap)
	m.routeRules = m.routeRules[:0]

	// Each tracker is optional and is closed only when present.
	if tracker := m.udpTracker; tracker != nil {
		tracker.Close()
	}
	if tracker := m.icmpTracker; tracker != nil {
		tracker.Close()
	}
	if tracker := m.tcpTracker; tracker != nil {
		tracker.Close()
	}

	if running := m.forwarder.Load(); running != nil {
		running.Stop()
	}

	if m.logger == nil {
		return
	}

	// The logger gets at most two seconds to stop; a failure is logged, not returned.
	stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := m.logger.Stop(stopCtx); err != nil {
		log.Errorf("failed to shutdown logger: %v", err)
	}
}
// SetupEBPFProxyNoTrack delegates creation of the notrack rules for eBPF proxy
// loopback traffic to the native firewall. When no native firewall is
// configured it does nothing and returns nil.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
	if native := m.nativeFirewall; native != nil {
		return native.SetupEBPFProxyNoTrack(proxyPort, wgPort)
	}
	return nil
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {

View File

@@ -1,376 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule checks that AddRouteFiltering is
// idempotent: submitting an identical rule twice yields the same rule ID and
// does not grow the route rule list.
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
	manager := setupTestManager(t)

	var (
		sources = []netip.Prefix{
			netip.MustParsePrefix("100.64.1.0/24"),
			netip.MustParsePrefix("100.64.2.0/24"),
		}
		destination = fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
	)

	// First submission.
	first, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, destination,
		fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, fw.ActionAccept,
	)
	require.NoError(t, err)
	require.NotNil(t, first)

	// Second submission with exactly the same parameters.
	second, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, destination,
		fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, fw.ActionAccept,
	)
	require.NoError(t, err)
	require.NotNil(t, second)

	// Same ID expected, matching the nftables/iptables implementations.
	assert.Equal(t, first.ID(), second.ID(),
		"Adding the same rule twice should return the same rule ID (idempotent)")

	// Read the rule count under the manager's read lock.
	stored := func() int {
		manager.mutex.RLock()
		defer manager.mutex.RUnlock()
		return len(manager.routeRules)
	}()
	assert.Equal(t, 2, stored,
		"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs checks that two rules that
// differ in their parameters (here: the destination prefix) receive distinct IDs
// and are both stored.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
	manager := setupTestManager(t)
	sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}

	firstDst := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
	first, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, firstDst,
		fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, fw.ActionAccept,
	)
	require.NoError(t, err)

	// Only the policy ID and the destination differ from the first rule.
	secondDst := fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}
	second, err := manager.AddRouteFiltering(
		[]byte("policy-2"), sources, secondDst,
		fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, fw.ActionAccept,
	)
	require.NoError(t, err)

	assert.NotEqual(t, first.ID(), second.ID(),
		"Different rules should have different IDs")

	// Read the rule count under the manager's read lock.
	stored := func() int {
		manager.mutex.RLock()
		defer manager.mutex.RUnlock()
		return len(manager.routeRules)
	}()
	assert.Equal(t, 3, stored, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap checks that re-adding an identical route
// rule (as a network map update would) leaves matching traffic allowed
// throughout.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
	manager := setupTestManager(t)

	var (
		sources     = []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
		destination = fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
	)

	first, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, destination,
		fw.ProtocolTCP, nil, nil, fw.ActionAccept,
	)
	require.NoError(t, err)

	srcIP := netip.MustParseAddr("100.64.1.5")
	dstIP := netip.MustParseAddr("192.168.1.10")
	// trafficPasses evaluates the route ACLs for one fixed TCP flow.
	trafficPasses := func() bool {
		_, ok := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
		return ok
	}
	require.True(t, trafficPasses(), "Traffic should pass with rule in place")

	// Same rule again, simulating a network map update.
	second, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, destination,
		fw.ProtocolTCP, nil, nil, fw.ActionAccept,
	)
	require.NoError(t, err)

	// With idempotent IDs both handles carry the same ID, so the ACL manager
	// would not delete the first one during cleanup. Were the IDs different,
	// deleting the first rule would remove the only match and open a gap.
	if first.ID() != second.ID() {
		require.NoError(t, manager.DeleteRouteRule(first))
	}

	assert.True(t, trafficPasses(),
		"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent checks that blockInvalidRouted installs a
// single drop rule for the WireGuard network prefix and that repeated calls
// hand back that same rule instead of adding duplicates.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
	ctrl := gomock.NewController(t)
	dev := mocks.NewMockDevice(ctrl)
	dev.EXPECT().MTU().Return(1500, nil).AnyTimes()

	wgNet := netip.MustParsePrefix("100.64.0.1/16")
	ifaceMock := &IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
		AddressFunc: func() wgaddr.Address {
			return wgaddr.Address{IP: wgNet.Addr(), Network: wgNet}
		},
		GetDeviceFunc: func() *device.FilteredDevice {
			return &device.FilteredDevice{Device: dev}
		},
		GetWGDeviceFunc: func() *wgdevice.Device {
			return &wgdevice.Device{}
		},
	}

	manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, manager.Close(nil))
	})

	// Invoke blockInvalidRouted three times in a row.
	first, err := manager.blockInvalidRouted(ifaceMock)
	require.NoError(t, err)
	require.NotNil(t, first)

	second, err := manager.blockInvalidRouted(ifaceMock)
	require.NoError(t, err)
	require.NotNil(t, second)

	third, err := manager.blockInvalidRouted(ifaceMock)
	require.NoError(t, err)
	require.NotNil(t, third)

	// Every call must report the same rule.
	assert.Equal(t, first.ID(), second.ID(), "Second call should return same rule")
	assert.Equal(t, second.ID(), third.ID(), "Third call should return same rule")

	// Exactly one route rule may be stored; read the count under the read lock.
	stored := func() int {
		manager.mutex.RLock()
		defer manager.mutex.RUnlock()
		return len(manager.routeRules)
	}()
	assert.Equal(t, 1, stored, "Should have exactly 1 block rule after 3 calls")

	// The rule must deny traffic whose destination lies inside the WG prefix.
	srcIP := netip.MustParseAddr("10.0.0.1")
	dstIP := netip.MustParseAddr("100.64.0.50")
	_, allowed := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
	assert.False(t, allowed, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting checks that invoking
// EnableRouting several times (as each route update does) keeps a single block
// rule in routeRules rather than appending a new one per call.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
	ctrl := gomock.NewController(t)
	dev := mocks.NewMockDevice(ctrl)
	dev.EXPECT().MTU().Return(1500, nil).AnyTimes()

	wgNet := netip.MustParsePrefix("100.64.0.1/16")
	ifaceMock := &IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
		AddressFunc: func() wgaddr.Address {
			return wgaddr.Address{IP: wgNet.Addr(), Network: wgNet}
		},
		GetDeviceFunc: func() *device.FilteredDevice {
			return &device.FilteredDevice{Device: dev}
		},
		GetWGDeviceFunc: func() *wgdevice.Device {
			return &wgdevice.Device{}
		},
	}

	manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, manager.Close(nil))
	})

	// Five consecutive EnableRouting calls stand in for repeated route updates.
	for attempt := 1; attempt <= 5; attempt++ {
		require.NoError(t, manager.EnableRouting())
	}

	// Read the rule count under the manager's read lock.
	stored := func() int {
		manager.mutex.RLock()
		defer manager.mutex.RUnlock()
		return len(manager.routeRules)
	}()
	assert.Equal(t, 1, stored,
		"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates checks that submitting one and the same
// route rule repeatedly never produces duplicate entries in routeRules.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
	manager := setupTestManager(t)

	var (
		sources     = []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
		destination = fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
	)

	// Five rounds, each standing in for a network map update carrying the same rule.
	for round := 1; round <= 5; round++ {
		added, err := manager.AddRouteFiltering(
			[]byte("policy-1"), sources, destination,
			fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, fw.ActionAccept,
		)
		require.NoError(t, err)
		require.NotNil(t, added)
	}

	// Read the rule count under the manager's read lock.
	stored := func() int {
		manager.mutex.RLock()
		defer manager.mutex.RUnlock()
		return len(manager.routeRules)
	}()
	assert.Equal(t, 2, stored,
		"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd checks that a route rule added twice
// (and therefore stored once) is fully removed by a single DeleteRouteRule call.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
	manager := setupTestManager(t)

	var (
		sources     = []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
		destination = fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
	)

	// Two identical submissions.
	first, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, destination,
		fw.ProtocolTCP, nil, nil, fw.ActionAccept,
	)
	require.NoError(t, err)

	second, err := manager.AddRouteFiltering(
		[]byte("policy-1"), sources, destination,
		fw.ProtocolTCP, nil, nil, fw.ActionAccept,
	)
	require.NoError(t, err)
	require.Equal(t, first.ID(), second.ID(), "Should return same rule ID")

	// Remove the rule through the handle returned by the first call.
	require.NoError(t, manager.DeleteRouteRule(first))

	// With the rule gone, the flow it used to allow must be rejected.
	srcIP := netip.MustParseAddr("100.64.1.5")
	dstIP := netip.MustParseAddr("192.168.1.10")
	_, allowed := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
	assert.False(t, allowed, "Traffic should not pass after rule deletion")
}
// setupTestManager builds a Manager on top of a mocked interface whose address
// is 100.64.0.1/16, enables routing on it, and registers a cleanup that closes
// the manager when the test finishes.
func setupTestManager(t *testing.T) *Manager {
	t.Helper()

	mockDev := mocks.NewMockDevice(gomock.NewController(t))
	mockDev.EXPECT().MTU().Return(1500, nil).AnyTimes()

	wgNet := netip.MustParsePrefix("100.64.0.1/16")
	ifaceMock := &IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
		AddressFunc: func() wgaddr.Address {
			return wgaddr.Address{IP: wgNet.Addr(), Network: wgNet}
		},
		GetDeviceFunc: func() *device.FilteredDevice {
			return &device.FilteredDevice{Device: mockDev}
		},
		GetWGDeviceFunc: func() *wgdevice.Device {
			return &wgdevice.Device{}
		},
	}

	mgr, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
	require.NoError(t, err)
	require.NoError(t, mgr.EnableRouting())

	t.Cleanup(func() {
		require.NoError(t, mgr.Close(nil))
	})
	return mgr
}

View File

@@ -263,158 +263,6 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules checks that drop rules land in the deny map
// and that deleting them one by one leaves no orphaned entry for the peer IP.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
	ifaceMock := &IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}
	m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, m.Close(nil))
	}()

	ip := net.ParseIP("192.168.1.1")
	addr := netip.MustParseAddr("192.168.1.1")

	// denyRuleCount reads, under the read lock, how many deny rules exist for addr.
	denyRuleCount := func() int {
		m.mutex.RLock()
		defer m.mutex.RUnlock()
		return len(m.incomingDenyRules[addr])
	}

	// Two drop rules on different destination ports.
	sshRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
		&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
	require.NoError(t, err)
	httpRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
		&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
	require.NoError(t, err)
	require.Equal(t, 2, denyRuleCount(), "Should have exactly 2 deny rules")

	// Removing the first rule leaves one behind.
	require.NoError(t, m.DeletePeerRule(sshRule[0]))
	require.Equal(t, 1, denyRuleCount(), "Should have 1 deny rule after deleting first")

	// Removing the second rule must also drop the per-IP map entry.
	require.NoError(t, m.DeletePeerRule(httpRule[0]))

	m.mutex.RLock()
	_, stillPresent := m.incomingDenyRules[addr]
	m.mutex.RUnlock()
	require.False(t, stillPresent, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak checks that cycling through add/delete of
// peer rules many times (as network map updates do) leaves both rule maps empty
// for the peer.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
	ifaceMock := &IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}
	m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, m.Close(nil))
	}()

	ip := net.ParseIP("192.168.1.1")
	addr := netip.MustParseAddr("192.168.1.1")

	// Ten rounds, each adding one drop and one accept rule and then removing both.
	for round := 1; round <= 10; round++ {
		dropRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
			&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
		require.NoError(t, err)

		acceptRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
			&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
		require.NoError(t, err)

		// Mirror the ACL manager's cleanup: drop rules first, then accept rules.
		for _, rule := range dropRules {
			require.NoError(t, m.DeletePeerRule(rule))
		}
		for _, rule := range acceptRules {
			require.NoError(t, m.DeletePeerRule(rule))
		}
	}

	// Both counts are taken inside a single read-lock section.
	m.mutex.RLock()
	remainingDeny, remainingAllow := len(m.incomingDenyRules[addr]), len(m.incomingRules[addr])
	m.mutex.RUnlock()

	require.Equal(t, 0, remainingDeny, "No deny rules should remain after cleanup")
	require.Equal(t, 0, remainingAllow, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP checks that accept and drop rules for one peer
// IP live in separate maps, so removing a rule of one kind leaves the other
// kind untouched.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
	ifaceMock := &IFaceMock{
		SetFilterFunc: func(device.PacketFilter) error { return nil },
	}
	m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, m.Close(nil))
	}()

	ip := net.ParseIP("192.168.1.1")

	// Accept on port 80, drop on port 22.
	allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
		&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
	require.NoError(t, err)
	denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
		&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
	require.NoError(t, err)

	addr := netip.MustParseAddr("192.168.1.1")

	// Both counts are taken inside a single read-lock section.
	m.mutex.RLock()
	allowCount, denyCount := len(m.incomingRules[addr]), len(m.incomingDenyRules[addr])
	m.mutex.RUnlock()
	require.Equal(t, 1, allowCount, "Should have 1 allow rule")
	require.Equal(t, 1, denyCount, "Should have 1 deny rule")

	// Removing the accept rule must leave the drop rule in place.
	require.NoError(t, m.DeletePeerRule(allowRule[0]))

	m.mutex.RLock()
	denyCountAfter := len(m.incomingDenyRules[addr])
	m.mutex.RUnlock()
	require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")

	// Removing the drop rule empties both maps for this IP.
	require.NoError(t, m.DeletePeerRule(denyRule[0]))

	m.mutex.RLock()
	_, denyExists := m.incomingDenyRules[addr]
	_, allowExists := m.incomingRules[addr]
	m.mutex.RUnlock()
	require.False(t, denyExists, "Deny rules should be empty")
	require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -5,8 +5,6 @@ import (
"context"
"fmt"
"io"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -18,18 +16,9 @@ const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
defaultLogChanSize = 1000
logChannelSize = 1000
)
// getLogChannelSize returns the capacity to use for the log message channel.
// The NB_USPFILTER_LOG_BUFFER environment variable wins when it holds a
// positive integer; an unset, empty, non-numeric or non-positive value falls
// back to defaultLogChanSize.
func getLogChannelSize() int {
	raw := os.Getenv("NB_USPFILTER_LOG_BUFFER")
	if raw == "" {
		return defaultLogChanSize
	}

	size, err := strconv.Atoi(raw)
	if err != nil || size <= 0 {
		return defaultLogChanSize
	}
	return size
}
type Level uint32
const (
@@ -80,7 +69,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, getLogChannelSize()),
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {

View File

@@ -1,169 +0,0 @@
package bind
import (
"errors"
"net"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn bundles an IPv4 and an IPv6 packet connection and sends
// each write through the one that matches the destination address family.
// Per the original design note, reads are not part of the hot path: ICEBind
// receives packets through BatchReader.ReadBatch() directly, and this type is
// only used by udpMux for sending.
type DualStackPacketConn struct {
	// Either connection may be nil when that address family is unavailable.
	ipv4Conn, ipv6Conn net.PacketConn

	// readFromWarn makes the "unexpected ReadFrom" warning fire only once.
	readFromWarn sync.Once
}
// NewDualStackPacketConn returns a DualStackPacketConn wrapping the given IPv4
// and IPv6 connections. The arguments are stored as passed, without validation.
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
	conn := new(DualStackPacketConn)
	conn.ipv4Conn = ipv4Conn
	conn.ipv6Conn = ipv6Conn
	return conn
}
// ReadFrom reads from the IPv4 connection when one is set, otherwise from the
// IPv6 connection; with neither set it returns net.ErrClosed.
//
// NOTE: per the original design note this method is not on the data path.
// ICEBind receives IPv4 and IPv6 packets through BatchReader.ReadBatch(), and
// udpMux only calls WriteTo() to send STUN responses (incoming STUN packets are
// forwarded via HandleSTUNMessage() from the receive path). The method exists
// to satisfy net.PacketConn, and the first call logs a warning because it is
// unexpected.
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	d.readFromWarn.Do(func() {
		log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
	})

	switch {
	case d.ipv4Conn != nil:
		return d.ipv4Conn.ReadFrom(b)
	case d.ipv6Conn != nil:
		return d.ipv6Conn.ReadFrom(b)
	default:
		return 0, nil, net.ErrClosed
	}
}
// WriteTo sends b to addr through the connection matching addr's IP family.
//
// addr must be a non-nil *net.UDPAddr. Addresses for which IP.To4() is non-nil
// (plain IPv4 and IPv4-mapped IPv6) go to the IPv4 connection; all others go to
// the IPv6 connection.
//
// It returns a *net.OpError wrapping errInvalidAddr when addr is not a usable
// *net.UDPAddr, or wrapping errNoIPv4Conn / errNoIPv6Conn when the connection
// for the required family is not set. Otherwise it returns whatever the
// underlying connection's WriteTo returns.
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
	udpAddr, ok := addr.(*net.UDPAddr)
	// A typed-nil *net.UDPAddr stored in the interface passes the type
	// assertion; reject it here so the udpAddr.IP access below cannot
	// dereference a nil pointer.
	if !ok || udpAddr == nil {
		return 0, &net.OpError{
			Op:   "write",
			Net:  "udp",
			Addr: addr,
			Err:  errInvalidAddr,
		}
	}

	// To4() returns nil for anything that is not representable as IPv4.
	if udpAddr.IP.To4() == nil {
		if d.ipv6Conn != nil {
			return d.ipv6Conn.WriteTo(b, addr)
		}
		return 0, &net.OpError{
			Op:   "write",
			Net:  "udp6",
			Addr: addr,
			Err:  errNoIPv6Conn,
		}
	}

	if d.ipv4Conn != nil {
		return d.ipv4Conn.WriteTo(b, addr)
	}
	return 0, &net.OpError{
		Op:   "write",
		Net:  "udp4",
		Addr: addr,
		Err:  errNoIPv4Conn,
	}
}
// Close closes every connection that is set, IPv4 first and then IPv6. A
// failure on one does not stop the other from being closed; all failures are
// collected and returned together through nberrors.FormatErrorOrNil.
func (d *DualStackPacketConn) Close() error {
	var merr *multierror.Error
	for _, conn := range []net.PacketConn{d.ipv4Conn, d.ipv6Conn} {
		if conn == nil {
			continue
		}
		if err := conn.Close(); err != nil {
			merr = multierror.Append(merr, err)
		}
	}
	return nberrors.FormatErrorOrNil(merr)
}
// LocalAddr reports the local address of the IPv4 connection when one is set,
// falling back to the IPv6 connection. It returns nil when neither is set.
func (d *DualStackPacketConn) LocalAddr() net.Addr {
	switch {
	case d.ipv4Conn != nil:
		return d.ipv4Conn.LocalAddr()
	case d.ipv6Conn != nil:
		return d.ipv6Conn.LocalAddr()
	default:
		return nil
	}
}
// SetDeadline applies t as the read and write deadline on every connection
// that is set, IPv4 first and then IPv6. Failures are collected and returned
// together through nberrors.FormatErrorOrNil.
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
	var merr *multierror.Error
	for _, conn := range []net.PacketConn{d.ipv4Conn, d.ipv6Conn} {
		if conn == nil {
			continue
		}
		if err := conn.SetDeadline(t); err != nil {
			merr = multierror.Append(merr, err)
		}
	}
	return nberrors.FormatErrorOrNil(merr)
}
// SetReadDeadline applies t as the read deadline on every connection that is
// set, IPv4 first and then IPv6. Failures are collected and returned together
// through nberrors.FormatErrorOrNil.
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
	var merr *multierror.Error
	for _, conn := range []net.PacketConn{d.ipv4Conn, d.ipv6Conn} {
		if conn == nil {
			continue
		}
		if err := conn.SetReadDeadline(t); err != nil {
			merr = multierror.Append(merr, err)
		}
	}
	return nberrors.FormatErrorOrNil(merr)
}
// SetWriteDeadline applies t as the write deadline on every connection that is
// set, IPv4 first and then IPv6. Failures are collected and returned together
// through nberrors.FormatErrorOrNil.
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
	var merr *multierror.Error
	for _, conn := range []net.PacketConn{d.ipv4Conn, d.ipv6Conn} {
		if conn == nil {
			continue
		}
		if err := conn.SetWriteDeadline(t); err != nil {
			merr = multierror.Append(merr, err)
		}
	}
	return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,119 +0,0 @@
package bind
import (
"net"
"testing"
)
var (
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
payload = make([]byte, 1200)
)
// BenchmarkWriteTo_DirectUDPConn is the baseline: it writes the payload to the
// IPv4 loopback address straight through a *net.UDPConn, with no wrapper.
// Write results are deliberately ignored.
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
	udp4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		b.Fatal(err)
	}
	defer udp4.Close()

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_, _ = udp4.WriteTo(payload, ipv4Addr)
	}
}
// BenchmarkWriteTo_DualStack_IPv4Only measures WriteTo on a DualStackPacketConn
// that has only an IPv4 socket, sending to the IPv4 loopback address.
// Write results are deliberately ignored.
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
	udp4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		b.Fatal(err)
	}
	defer udp4.Close()

	dual := NewDualStackPacketConn(udp4, nil)

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_, _ = dual.WriteTo(payload, ipv4Addr)
	}
}
// BenchmarkWriteTo_DualStack_IPv6Only measures WriteTo on a DualStackPacketConn
// that has only an IPv6 socket, sending to the IPv6 loopback address. The
// benchmark is skipped when an IPv6 socket cannot be opened.
// Write results are deliberately ignored.
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
	udp6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
	if err != nil {
		b.Skipf("IPv6 not available: %v", err)
	}
	defer udp6.Close()

	dual := NewDualStackPacketConn(nil, udp6)

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_, _ = dual.WriteTo(payload, ipv6Addr)
	}
}
// BenchmarkWriteTo_DualStack_Both_IPv4Traffic measures WriteTo on a
// DualStackPacketConn holding both sockets while all traffic goes to the IPv4
// loopback address. It is skipped when an IPv6 socket cannot be opened.
// Write results are deliberately ignored.
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
	udp4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		b.Fatal(err)
	}
	defer udp4.Close()

	udp6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
	if err != nil {
		b.Skipf("IPv6 not available: %v", err)
	}
	defer udp6.Close()

	dual := NewDualStackPacketConn(udp4, udp6)

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_, _ = dual.WriteTo(payload, ipv4Addr)
	}
}
// BenchmarkWriteTo_DualStack_Both_IPv6Traffic measures WriteTo on a
// DualStackPacketConn holding both sockets while all traffic goes to the IPv6
// loopback address. It is skipped when an IPv6 socket cannot be opened.
// Write results are deliberately ignored.
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
	udp4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		b.Fatal(err)
	}
	defer udp4.Close()

	udp6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
	if err != nil {
		b.Skipf("IPv6 not available: %v", err)
	}
	defer udp6.Close()

	dual := NewDualStackPacketConn(udp4, udp6)

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_, _ = dual.WriteTo(payload, ipv6Addr)
	}
}
// BenchmarkWriteTo_DualStack_Both_MixedTraffic measures WriteTo through a
// DualStackPacketConn holding both sockets while the destination alternates
// between the IPv4 and IPv6 address on every iteration. The benchmark is
// skipped when an IPv6 socket cannot be opened.
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
	v4Sock, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		b.Fatal(err)
	}
	defer v4Sock.Close()

	v6Sock, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
	if err != nil {
		b.Skipf("IPv6 not available: %v", err)
	}
	defer v6Sock.Close()

	wrapped := NewDualStackPacketConn(v4Sock, v6Sock)
	// Even iterations target IPv4, odd iterations target IPv6.
	targets := []net.Addr{ipv4Addr, ipv6Addr}

	b.ResetTimer()
	for sent := 0; sent < b.N; sent++ {
		dst := targets[sent&1]
		_, _ = wrapped.WriteTo(payload, dst)
	}
}

View File

@@ -1,191 +0,0 @@
package bind
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDualStackPacketConn_RoutesWritesToCorrectSocket checks that WriteTo
// forwards each packet to exactly one of the two mock sockets, chosen by the
// destination address family (IPv4-mapped IPv6 addresses count as IPv4).
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
	v4Sock := &mockPacketConn{network: "udp4"}
	v6Sock := &mockPacketConn{network: "udp6"}
	conn := NewDualStackPacketConn(v4Sock, v6Sock)

	type routeCase struct {
		name       string
		addr       *net.UDPAddr
		wantSocket string
	}
	// dest builds a destination on the fixed test port.
	dest := func(ip string) *net.UDPAddr {
		return &net.UDPAddr{IP: net.ParseIP(ip), Port: 1234}
	}
	cases := []routeCase{
		{name: "IPv4 address", addr: dest("192.168.1.1"), wantSocket: "udp4"},
		{name: "IPv6 address", addr: dest("2001:db8::1"), wantSocket: "udp6"},
		{name: "IPv4-mapped IPv6 goes to IPv4", addr: dest("::ffff:192.168.1.1"), wantSocket: "udp4"},
		{name: "IPv4 loopback", addr: dest("127.0.0.1"), wantSocket: "udp4"},
		{name: "IPv6 loopback", addr: dest("::1"), wantSocket: "udp6"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The mocks are shared between subtests, so reset their counters first.
			v4Sock.writeCount, v6Sock.writeCount = 0, 0

			written, err := conn.WriteTo([]byte("test"), tc.addr)
			require.NoError(t, err)
			assert.Equal(t, 4, written)

			switch tc.wantSocket {
			case "udp4":
				assert.Equal(t, 1, v4Sock.writeCount, "expected write to IPv4")
				assert.Equal(t, 0, v6Sock.writeCount, "expected no write to IPv6")
			default:
				assert.Equal(t, 0, v4Sock.writeCount, "expected no write to IPv4")
				assert.Equal(t, 1, v6Sock.writeCount, "expected write to IPv6")
			}
		})
	}
}
// TestDualStackPacketConn_IPv4OnlyRejectsIPv6 checks that a conn built with
// only an IPv4 socket accepts IPv4 destinations and returns a
// "no IPv6 connection" error for IPv6 destinations.
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
	conn := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
	msg := []byte("test")

	// An IPv4 destination is served by the configured socket.
	v4Dst := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
	_, err := conn.WriteTo(msg, v4Dst)
	require.NoError(t, err)

	// An IPv6 destination has no socket to go to.
	v6Dst := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}
	_, err = conn.WriteTo(msg, v6Dst)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no IPv6 connection")
}
// TestDualStackPacketConn_IPv6OnlyRejectsIPv4 checks that a conn built with
// only an IPv6 socket accepts IPv6 destinations and returns a
// "no IPv4 connection" error for IPv4 destinations.
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
	conn := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
	msg := []byte("test")

	// An IPv6 destination is served by the configured socket.
	v6Dst := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}
	_, err := conn.WriteTo(msg, v6Dst)
	require.NoError(t, err)

	// An IPv4 destination has no socket to go to.
	v4Dst := &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}
	_, err = conn.WriteTo(msg, v4Dst)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no IPv4 connection")
}
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
// only reads from one socket (IPv4 preferred). This is fine because the actual
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
	ipv4Conn := &mockPacketConn{
		network:  "udp4",
		readData: []byte("from ipv4"),
		readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
	}
	ipv6Conn := &mockPacketConn{
		network:  "udp6",
		readData: []byte("from ipv6"),
		readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
	}
	dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)

	buf := make([]byte, 100)
	n, addr, err := dualStack.ReadFrom(buf)
	require.NoError(t, err)

	// reads from IPv4 (preferred) - this is expected behavior
	assert.Equal(t, "from ipv4", string(buf[:n]))

	// Check the dynamic type first so that an unexpected address type is
	// reported as a test failure rather than a panic.
	udpAddr, ok := addr.(*net.UDPAddr)
	require.True(t, ok, "expected *net.UDPAddr, got %T", addr)
	assert.Equal(t, "192.168.1.1", udpAddr.IP.String())
}
// TestDualStackPacketConn_LocalAddrPrefersIPv4 checks which local address
// LocalAddr reports for every combination of present and absent sockets:
// IPv4 when available, otherwise IPv6, otherwise nil.
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
	v4Local := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
	v6Local := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}

	type localAddrCase struct {
		name     string
		ipv4     net.PacketConn
		ipv6     net.PacketConn
		wantAddr net.Addr
	}
	// Fields left out of a case stay nil, meaning "socket not configured".
	cases := []localAddrCase{
		{
			name:     "both available returns IPv4",
			ipv4:     &mockPacketConn{localAddr: v4Local},
			ipv6:     &mockPacketConn{localAddr: v6Local},
			wantAddr: v4Local,
		},
		{
			name:     "IPv4 only",
			ipv4:     &mockPacketConn{localAddr: v4Local},
			wantAddr: v4Local,
		},
		{
			name:     "IPv6 only",
			ipv6:     &mockPacketConn{localAddr: v6Local},
			wantAddr: v6Local,
		},
		{
			name: "neither returns nil",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			conn := NewDualStackPacketConn(tc.ipv4, tc.ipv6)
			got := conn.LocalAddr()
			assert.Equal(t, tc.wantAddr, got)
		})
	}
}
// mock
type mockPacketConn struct {
network string
writeCount int
readData []byte
readAddr net.Addr
localAddr net.Addr
}
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
if m.readData != nil {
return copy(b, m.readData), m.readAddr, nil
}
return 0, nil, nil
}
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
m.writeCount++
return len(b), nil
}
func (m *mockPacketConn) Close() error { return nil }
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@@ -14,6 +14,7 @@ import (
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -27,7 +28,22 @@ type receiverCreator struct {
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
}
// ICEBind is a bind implementation with two main features:
@@ -57,8 +73,6 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
@@ -104,12 +118,6 @@ func (s *ICEBind) Close() error {
close(s.closedChan)
s.muUDPMux.Lock()
s.ipv4Conn = nil
s.ipv6Conn = nil
s.udpMux = nil
s.muUDPMux.Unlock()
return s.StdNetBind.Close()
}
@@ -167,18 +175,19 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
// Detect IPv4 vs IPv6 from connection's local address
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
s.ipv4Conn = conn
} else {
s.ipv6Conn = conn
}
s.createOrUpdateMux()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
@@ -186,13 +195,12 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint:staticcheck
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
@@ -214,12 +222,12 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue
}
sizes[i] = msg.N
@@ -240,38 +248,6 @@ func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxO
}
}
// createOrUpdateMux (re)builds s.udpMux from whichever of s.ipv4Conn and
// s.ipv6Conn are currently set: both yield a dual-stack packet conn, one
// yields that single wrapped conn, and none leaves s.udpMux untouched.
// Must be called with muUDPMux held.
func (s *ICEBind) createOrUpdateMux() {
	v4, v6 := s.ipv4Conn, s.ipv6Conn
	if v4 == nil && v6 == nil {
		// Nothing to multiplex yet.
		return
	}

	var muxConn net.PacketConn
	if v4 != nil && v6 != nil {
		muxConn = NewDualStackPacketConn(
			nbnet.WrapPacketConn(v4),
			nbnet.WrapPacketConn(v6),
		)
	} else if v4 != nil {
		muxConn = nbnet.WrapPacketConn(v4)
	} else {
		muxConn = nbnet.WrapPacketConn(v6)
	}

	// Don't close the old mux - it doesn't own the underlying connections.
	// The sockets are managed by WireGuard's StdNetBind, not by us.
	params := udpmux.UniversalUDPMuxParams{
		UDPConn:   muxConn,
		Net:       s.transportNet,
		FilterFn:  s.filterFn,
		WGAddress: s.address,
		MTU:       s.mtu,
	}
	s.udpMux = udpmux.NewUniversalUDPMuxDefault(params)
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
@@ -284,14 +260,9 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
s.muUDPMux.Lock()
mux := s.udpMux
s.muUDPMux.Unlock()
if mux != nil {
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
log.Warnf("failed to handle STUN packet: %v", muxErr)
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}

View File

@@ -1,324 +0,0 @@
package bind
import (
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestICEBind_CreatesReceiverForBothIPv4AndIPv6 registers an IPv4 and then an
// IPv6 receiver and checks after each step which connections the bind has
// stored and that the mux exists.
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
	b := setupICEBind(t)

	v4Sock, v6Sock := createDualStackConns(t)
	defer v4Sock.Close()
	defer v6Sock.Close()

	creator := receiverCreator{b}
	msgPool := createMsgPool()

	// inspect runs check while holding the mux lock that guards the inspected fields.
	inspect := func(check func()) {
		b.muUDPMux.Lock()
		defer b.muUDPMux.Unlock()
		check()
	}

	// Simulate wireguard-go calling CreateReceiverFn for IPv4
	recvV4 := creator.CreateReceiverFn(ipv4.NewPacketConn(v4Sock), v4Sock, false, msgPool)
	require.NotNil(t, recvV4)
	inspect(func() {
		assert.NotNil(t, b.ipv4Conn, "should store IPv4 connection")
		assert.Nil(t, b.ipv6Conn, "IPv6 not added yet")
		assert.NotNil(t, b.udpMux, "mux should be created after first connection")
	})

	// Simulate wireguard-go calling CreateReceiverFn for IPv6
	recvV6 := creator.CreateReceiverFn(ipv6.NewPacketConn(v6Sock), v6Sock, false, msgPool)
	require.NotNil(t, recvV6)
	inspect(func() {
		assert.NotNil(t, b.ipv4Conn, "should still have IPv4 connection")
		assert.NotNil(t, b.ipv6Conn, "should now have IPv6 connection")
		assert.NotNil(t, b.udpMux, "mux should still exist")
	})

	iceMux, err := b.GetICEMux()
	require.NoError(t, err)
	require.NotNil(t, iceMux)
}
// TestICEBind_WorksWithIPv4Only registers only an IPv4 receiver and checks
// that the bind stores that connection, leaves the IPv6 slot empty and still
// exposes a mux.
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
	b := setupICEBind(t)

	v4Sock, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	require.NoError(t, err)
	defer v4Sock.Close()

	creator := receiverCreator{b}
	receive := creator.CreateReceiverFn(ipv4.NewPacketConn(v4Sock), v4Sock, false, createMsgPool())
	require.NotNil(t, receive)

	// The inspected fields are guarded by the mux lock.
	func() {
		b.muUDPMux.Lock()
		defer b.muUDPMux.Unlock()
		assert.NotNil(t, b.ipv4Conn)
		assert.Nil(t, b.ipv6Conn)
		assert.NotNil(t, b.udpMux)
	}()

	iceMux, err := b.GetICEMux()
	require.NoError(t, err)
	require.NotNil(t, iceMux)
}
// TestICEBind_WorksWithIPv6Only registers only an IPv6 receiver and checks
// that the bind stores that connection, leaves the IPv4 slot empty and still
// exposes a mux. The test is skipped when an IPv6 socket cannot be opened.
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
	b := setupICEBind(t)

	v6Sock, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
	if err != nil {
		t.Skipf("IPv6 not available: %v", err)
	}
	defer v6Sock.Close()

	creator := receiverCreator{b}
	receive := creator.CreateReceiverFn(ipv6.NewPacketConn(v6Sock), v6Sock, false, createMsgPool())
	require.NotNil(t, receive)

	// The inspected fields are guarded by the mux lock.
	func() {
		b.muUDPMux.Lock()
		defer b.muUDPMux.Unlock()
		assert.Nil(t, b.ipv4Conn)
		assert.NotNil(t, b.ipv6Conn)
		assert.NotNil(t, b.udpMux)
	}()

	iceMux, err := b.GetICEMux()
	require.NoError(t, err)
	require.NotNil(t, iceMux)
}
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously sends one packet to an
// IPv4 peer and one to an IPv6 peer through a single DualStackPacketConn and
// checks that each peer receives its payload from the local socket of the
// matching address family.
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
	// Remote side: one listener per address family.
	peerV4 := listenUDP(t, "udp4", "127.0.0.1:0")
	defer peerV4.Close()

	peerV6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
	if err != nil {
		t.Skipf("IPv6 not available: %v", err)
	}
	defer peerV6.Close()

	// Local side: the two sockets combined into one dual-stack conn.
	localV4 := listenUDP(t, "udp4", "127.0.0.1:0")
	defer localV4.Close()
	localV6 := listenUDP(t, "udp6", "[::1]:0")
	defer localV6.Close()

	conn := NewDualStackPacketConn(localV4, localV6)

	_, err = conn.WriteTo([]byte("to-ipv4"), peerV4.LocalAddr())
	require.NoError(t, err)
	_, err = conn.WriteTo([]byte("to-ipv6"), peerV6.LocalAddr())
	require.NoError(t, err)

	buf := make([]byte, 100)
	// expectFrom reads one datagram on peer (waiting at most a second) and
	// checks its payload and that its source port is local's port.
	expectFrom := func(peer, local *net.UDPConn, want string) {
		_ = peer.SetReadDeadline(time.Now().Add(time.Second))
		n, from, err := peer.ReadFrom(buf)
		require.NoError(t, err)
		assert.Equal(t, want, string(buf[:n]))
		assert.Equal(t, local.LocalAddr().(*net.UDPAddr).Port, from.(*net.UDPAddr).Port)
	}

	expectFrom(peerV4, localV4, "to-ipv4")
	expectFrom(peerV6, localV6, "to-ipv6")
}
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
// and IPv6 peers and verifies that nothing is misrouted: the IPv4 peer only gets
// "v4-" packets and the IPv6 peer only gets "v6-" packets. UDP gives no delivery
// guarantee, so some packet loss is tolerated; each peer only has to receive at
// least one packet.
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
	ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
	defer ipv4Peer.Close()

	ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
	if err != nil {
		t.Skipf("IPv6 not available: %v", err)
	}
	defer ipv6Peer.Close()

	ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
	defer ipv4Local.Close()
	ipv6Local := listenUDP(t, "udp6", "[::1]:0")
	defer ipv6Local.Close()

	dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)

	const packetsPerFamily = 500
	// Buffered to the maximum number of packets so receivers never block on send.
	ipv4Received := make(chan string, packetsPerFamily)
	ipv6Received := make(chan string, packetsPerFamily)
	startGate := make(chan struct{})

	var wg sync.WaitGroup

	// receive collects up to packetsPerFamily datagrams from peer and stops
	// early on the first read error (including the deadline set below).
	receive := func(peer *net.UDPConn, out chan<- string) {
		defer wg.Done()
		buf := make([]byte, 100)
		for i := 0; i < packetsPerFamily; i++ {
			n, _, err := peer.ReadFrom(buf)
			if err != nil {
				return
			}
			out <- string(buf[:n])
		}
	}

	// send waits for the start gate, then writes packetsPerFamily numbered
	// payloads ("<prefix>-0000", ...) to dst; write errors are ignored.
	send := func(prefix string, dst net.Addr) {
		defer wg.Done()
		<-startGate
		for i := 0; i < packetsPerFamily; i++ {
			_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("%s-%04d", prefix, i)), dst)
		}
	}

	wg.Add(4)
	go receive(ipv4Peer, ipv4Received)
	go receive(ipv6Peer, ipv6Received)
	go send("v4", ipv4Peer.LocalAddr())
	go send("v6", ipv6Peer.LocalAddr())

	close(startGate)

	// If packets were dropped the receivers would block forever; expire their
	// reads after a grace period. The timer is stopped once the test is done
	// so the callback does not fire against already-closed sockets.
	unblock := time.AfterFunc(5*time.Second, func() {
		_ = ipv4Peer.SetReadDeadline(time.Now())
		_ = ipv6Peer.SetReadDeadline(time.Now())
	})
	defer unblock.Stop()

	wg.Wait()
	close(ipv4Received)
	close(ipv6Received)

	ipv4Count := 0
	for pkt := range ipv4Received {
		require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
		ipv4Count++
	}

	ipv6Count := 0
	for pkt := range ipv6Received {
		require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
		ipv6Count++
	}

	// Loss is acceptable, total silence is not.
	assert.Greater(t, ipv4Count, 0, "IPv4 peer received no packets")
	assert.Greater(t, ipv6Count, 0, "IPv6 peer received no packets")
}
// TestICEBind_DetectsAddressFamilyFromConnection checks that the
// "local IP has a 4-byte form" test classifies udp4 listeners as IPv4 and
// udp6 listeners as not IPv4. Cases whose socket cannot be opened are skipped.
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
	type familyCase struct {
		name     string
		network  string
		addr     string
		wantIPv4 bool
	}
	cases := []familyCase{
		{name: "IPv4 any", network: "udp4", addr: "0.0.0.0:0", wantIPv4: true},
		{name: "IPv4 loopback", network: "udp4", addr: "127.0.0.1:0", wantIPv4: true},
		{name: "IPv6 any", network: "udp6", addr: "[::]:0", wantIPv4: false},
		{name: "IPv6 loopback", network: "udp6", addr: "[::1]:0", wantIPv4: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resolved, err := net.ResolveUDPAddr(tc.network, tc.addr)
			require.NoError(t, err)

			sock, err := net.ListenUDP(tc.network, resolved)
			if err != nil {
				t.Skipf("%s not available: %v", tc.network, err)
			}
			defer sock.Close()

			bound := sock.LocalAddr().(*net.UDPAddr)
			gotIPv4 := bound.IP.To4() != nil
			assert.Equal(t, tc.wantIPv4, gotIPv4)
		})
	}
}
// helpers
// setupICEBind builds an ICEBind for tests on top of a stdnet transport, with
// no filter function, WireGuard address 100.64.0.1 in 100.64.0.0/10 and an
// MTU of 1280. It fails the test if the transport cannot be created.
func setupICEBind(t *testing.T) *ICEBind {
	t.Helper()

	stdNet, err := stdnet.NewNet()
	require.NoError(t, err)

	const mtu = 1280
	wgAddr := wgaddr.Address{
		Network: netip.MustParsePrefix("100.64.0.0/10"),
		IP:      netip.MustParseAddr("100.64.0.1"),
	}

	return NewICEBind(stdNet, nil, wgAddr, mtu)
}
// createDualStackConns opens one wildcard UDP socket per address family on
// ephemeral ports and returns them as (IPv4, IPv6). The test fails if the
// IPv4 socket cannot be opened and is skipped if the IPv6 one cannot.
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
	t.Helper()

	v4Sock, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	require.NoError(t, err)

	v6Sock, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
	if err == nil {
		return v4Sock, v6Sock
	}

	// Release the already-open IPv4 socket before skipping.
	_ = v4Sock.Close()
	t.Skipf("IPv6 not available: %v", err)
	return v4Sock, v6Sock
}
// createMsgPool returns a pool whose New function produces a pointer to a
// one-element message slice; the element has a single buffer slot and an
// empty out-of-band buffer with capacity 40.
func createMsgPool() *sync.Pool {
	newBatch := func() any {
		batch := []ipv6.Message{
			{
				Buffers: make(net.Buffers, 1),
				OOB:     make([]byte, 0, 40),
			},
		}
		return &batch
	}
	return &sync.Pool{New: newBatch}
}
// listenUDP resolves addr for the given network and opens a UDP socket on
// it, failing the test if either step returns an error.
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
	t.Helper()

	resolved, err := net.ResolveUDPAddr(network, addr)
	require.NoError(t, err)

	sock, err := net.ListenUDP(network, resolved)
	require.NoError(t, err)

	return sock
}

View File

@@ -3,22 +3,8 @@ package configurer
import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// buildPresharedKeyConfig returns a wgtypes.Config containing a single peer
// entry that sets psk as the preshared key of peerKey, with the entry's
// UpdateOnly flag taken from updateOnly.
// This is a shared helper used by both kernel and userspace configurers.
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
	peer := wgtypes.PeerConfig{
		PublicKey:    peerKey,
		PresharedKey: &psk,
		UpdateOnly:   updateOnly,
	}
	return wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}}
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -46,18 +48,6 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
// SetPresharedKey parses peerKey and applies psk as that peer's preshared key
// through c.configure. It returns the parse error if peerKey is not a valid key.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
	pubKey, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return err
	}
	return c.configure(buildPresharedKeyConfig(pubKey, psk, updateOnly))
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -289,7 +279,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: [32]byte(p.PresharedKey),
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint

View File

@@ -22,16 +22,17 @@ import (
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -71,18 +72,6 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
// SetPresharedKey parses peerKey and applies psk as that peer's preshared key
// by sending the serialized config to the device via IpcSet. It returns the
// parse error if peerKey is not a valid key.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
	pubKey, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return err
	}
	ipcRequest := toWgUserspaceString(buildPresharedKeyConfig(pubKey, psk, updateOnly))
	return c.device.IpcSet(ipcRequest)
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -433,25 +422,13 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.Remove {
sb.WriteString("remove=true\n")
}
if p.UpdateOnly {
sb.WriteString("update_only=true\n")
}
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
@@ -461,6 +438,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
}
return sb.String()
}
@@ -558,7 +543,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
host, portStr, err := net.SplitHostPort(val)
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
@@ -614,9 +599,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
if pskKey, err := hexToWireguardKey(val); err == nil {
currentPeer.PresharedKey = [32]byte(pskKey)
}
currentPeer.PresharedKey = true
}
}
}

View File

@@ -12,7 +12,7 @@ type Peer struct {
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey [32]byte
PresharedKey bool
}
type Stats struct {

View File

@@ -29,9 +29,8 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
mutex sync.RWMutex
}
// newDeviceFilter constructor function
@@ -41,20 +40,6 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once and returns that
// close's error; every later call is a no-op that returns nil.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
	var closeErr error
	d.closeOnce.Do(func() {
		closeErr = d.Device.Close()
	})
	return closeErr
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -82,9 +82,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -17,7 +17,6 @@ type WGConfigurer interface {
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
@@ -51,7 +50,6 @@ func ValidateMTU(mtu uint16) error {
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
Free() error
}
@@ -82,12 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetProxyPort returns the proxy port used by the WireGuard proxy.
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
// The value is whatever w.wgProxyFactory.GetProxyPort() reports; this method
// adds no logic of its own.
func (w *WGIface) GetProxyPort() uint16 {
return w.wgProxyFactory.GetProxyPort()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
@@ -229,10 +221,6 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if nbnetstack.IsEnabled() {
return errors.FormatErrorOrNil(result)
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
if err := w.Destroy(); err != nil {
@@ -309,19 +297,6 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
// SetPresharedKey sets or updates the preshared key for a peer by delegating
// to the interface's configurer while holding w.mu. It returns
// ErrIfaceNotFound when no configurer is set.
// If updateOnly is true, only updates existing peer; if false, creates or updates.
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
	w.mu.Lock()
	defer w.mu.Unlock()

	if cfg := w.configurer; cfg != nil {
		return cfg.SetPresharedKey(peerKey, psk, updateOnly)
	}
	return ErrIfaceNotFound
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)

View File

@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
}
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugf("dialing %s %s", network, addr)
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Debugf("failed to deal connection: %s", err)
log.Warnf("NSDialer.Dial failed: %s", err)
}
return conn, err
}

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
}()
return t.tundev, tunNet, nil
return nsTunDev, tunNet, nil
}
func (t *NetStackTun) Close() error {

View File

@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to start package redirection: %v", err)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = ep
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
// addrToEndpoint converts a UDP address into a bind.Endpoint.
//
// IPv4 addresses (including IPv4-mapped IPv6 forms) produce a plain IPv4
// endpoint, exactly as before. IPv6 addresses are now carried through as
// IPv6 instead of being collapsed into an invalid zero address, which is what
// converting only addr.IP.To4() did for them.
//
// If addr.IP is not a valid 4- or 16-byte IP the endpoint's address is the
// zero netip.Addr. addr must not be nil.
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
	// A failed conversion leaves ip as the zero Addr; Unmap on it is a no-op.
	ip, _ := netip.AddrFromSlice(addr.IP)
	// Unmap turns ::ffff:a.b.c.d into a.b.c.d so IPv4 results are unchanged.
	addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
	return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
@@ -212,16 +212,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
}
// addrToEndpoint converts a UDP address into a bind.Endpoint, unmapping
// IPv4-mapped IPv6 addresses to plain IPv4. It returns an error when addr is
// nil or when addr.IP cannot be converted to a netip.Addr.
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
	if addr == nil {
		return nil, fmt.Errorf("invalid address")
	}

	if ip, ok := netip.AddrFromSlice(addr.IP); ok {
		endpoint := &bind.Endpoint{
			AddrPort: netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)),
		}
		return endpoint, nil
	}

	return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}

View File

@@ -8,6 +8,8 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
@@ -24,10 +26,13 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
proxyPort int
mtu uint16
ebpfManager ebpfMgr.Manager
@@ -35,8 +40,7 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
lastUsedPort uint16
rawConnIPv4 net.PacketConn
rawConnIPv6 net.PacketConn
rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
@@ -58,39 +62,23 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
// Listen loads the eBPF program and starts listening on the proxy port
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
proxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.proxyPort = proxyPort
// Prepare IPv4 raw socket (required)
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
// Prepare IPv6 raw socket (optional)
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil {
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
return err
}
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil {
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
}
if p.rawConnIPv6 != nil {
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
}
}
return err
}
addr := net.UDPAddr{
Port: proxyPort,
Port: wgPorxyPort,
IP: net.ParseIP(loopbackAddr),
}
@@ -106,7 +94,7 @@ func (p *WGEBPFProxy) Listen() error {
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", proxyPort)
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
return nil
}
@@ -147,25 +135,12 @@ func (p *WGEBPFProxy) Free() error {
result = multierror.Append(result, err)
}
if p.rawConnIPv4 != nil {
if err := p.rawConnIPv4.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if p.rawConnIPv6 != nil {
if err := p.rawConnIPv6.Close(); err != nil {
result = multierror.Append(result, err)
}
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
// GetProxyPort returns the proxy listening port.
func (p *WGEBPFProxy) GetProxyPort() uint16 {
return uint16(p.proxyPort)
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -241,3 +216,34 @@ generatePort:
}
return p.lastUsedPort, nil
}
// sendPkg wraps data in a hand-built IPv4/UDP packet and writes it to the raw
// socket, so the payload reaches the local WireGuard port with endpointAddr as
// its apparent sender.
//
// The packet source is endpointAddr (IP and port); the destination is
// localHostNetIP on p.localWGListenPort. Length fields and checksums are
// filled in while the layers are serialized.
//
// An error is returned when endpointAddr is nil or has no IP, when the layers
// cannot be serialized, or when the write to the raw socket fails.
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
	// Callers pass whatever endpoint is currently configured; reject a missing
	// one here instead of panicking on the dereference below.
	if endpointAddr == nil || endpointAddr.IP == nil {
		return fmt.Errorf("serialize layers: endpoint address is not set")
	}

	payload := gopacket.Payload(data)
	ipH := &layers.IPv4{
		DstIP:    localHostNetIP,
		SrcIP:    endpointAddr.IP,
		Version:  4,
		TTL:      64,
		Protocol: layers.IPProtocolUDP,
	}
	udpH := &layers.UDP{
		SrcPort: layers.UDPPort(endpointAddr.Port),
		DstPort: layers.UDPPort(p.localWGListenPort),
	}

	// The UDP checksum covers a pseudo-header taken from the IP layer.
	err := udpH.SetNetworkLayerForChecksum(ipH)
	if err != nil {
		return fmt.Errorf("set network layer for checksum: %w", err)
	}

	layerBuffer := gopacket.NewSerializeBuffer()

	err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
	if err != nil {
		return fmt.Errorf("serialize layers: %w", err)
	}

	if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
		return fmt.Errorf("write to raw conn: %w", err)
	}
	return nil
}

View File

@@ -10,89 +10,12 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
var (
	// Reported when an IPv6 endpoint is used but the eBPF proxy has no IPv6 raw socket.
	errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
	// Reported when an IPv4 endpoint is used but the eBPF proxy has no IPv4 raw socket.
	errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
	// Loopback addresses used as the destination of the crafted packets.
	localHostNetIPv4 = net.ParseIP("127.0.0.1")
	localHostNetIPv6 = net.ParseIP("::1")
	// Serialization options shared by every send: have gopacket compute
	// checksums and fix up length fields.
	serializeOpts = gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths: true,
	}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending.
// One instance describes a single apparent source endpoint; it is built by
// NewPacketHeaders and consumed when packets are serialized and written to the
// raw socket.
type PacketHeaders struct {
	// ipH is the IP header layer: *layers.IPv4 or *layers.IPv6 depending on the endpoint.
	ipH gopacket.SerializableLayer
	// udpH is the UDP header going from the endpoint port to the local WireGuard port.
	udpH *layers.UDP
	// layerBuffer is the serialization buffer reused between sends.
	layerBuffer gopacket.SerializeBuffer
	// localHostAddr is the loopback address (127.0.0.1 or ::1) the packet is written to.
	localHostAddr net.IP
	// isIPv4 records which address family the headers were built for.
	isIPv4 bool
}
// NewPacketHeaders prepares the IP and UDP headers used to deliver packets to
// the local WireGuard port with endpoint as the apparent source.
//
// The address family follows endpoint.IP: an address that has a 4-byte form
// gets an IPv4 header addressed to 127.0.0.1, anything else gets an IPv6
// header addressed to ::1. The UDP header goes from endpoint.Port to
// localWGListenPort and is tied to the IP header for checksum calculation.
//
// An error is returned if the UDP layer refuses the network layer for
// checksum use.
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
	udpLayer := &layers.UDP{
		SrcPort: layers.UDPPort(endpoint.Port),
		DstPort: layers.UDPPort(localWGListenPort),
	}

	result := &PacketHeaders{
		udpH:   udpLayer,
		isIPv4: endpoint.IP.To4() != nil,
	}

	// The same concrete header is needed twice: once as a serializable layer
	// and once as the network layer the UDP checksum is derived from.
	var checksumLayer gopacket.NetworkLayer
	if result.isIPv4 {
		v4Header := &layers.IPv4{
			DstIP:    localHostNetIPv4,
			SrcIP:    endpoint.IP,
			Version:  4,
			TTL:      64,
			Protocol: layers.IPProtocolUDP,
		}
		result.ipH = v4Header
		result.localHostAddr = localHostNetIPv4
		checksumLayer = v4Header
	} else {
		v6Header := &layers.IPv6{
			DstIP:      localHostNetIPv6,
			SrcIP:      endpoint.IP,
			Version:    6,
			HopLimit:   64,
			NextHeader: layers.IPProtocolUDP,
		}
		result.ipH = v6Header
		result.localHostAddr = localHostNetIPv6
		checksumLayer = v6Header
	}

	if err := udpLayer.SetNetworkLayerForChecksum(checksumLayer); err != nil {
		return nil, fmt.Errorf("set network layer for checksum: %w", err)
	}

	// Allocate the reusable buffer only once the headers are known to be valid.
	result.layerBuffer = gopacket.NewSerializeBuffer()
	return result, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
@@ -101,10 +24,8 @@ type ProxyWrapper struct {
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
paused bool
pausedCond *sync.Cond
@@ -120,32 +41,15 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -164,8 +68,7 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock()
p.paused = false
p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted {
p.isStarted = true
@@ -188,32 +91,10 @@ func (p *ProxyWrapper) Pause() {
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
@@ -255,7 +136,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock()
if err != nil {
@@ -281,29 +162,3 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
return n, nil
}
// sendPkg serializes data behind the prepared IP/UDP headers and writes the
// resulting packet to the wrapper's raw connection, addressed to the loopback
// address stored in header.
//
// The header's serialize buffer is cleared when the call returns, whether or
// not the send succeeded; a failure to clear is only logged.
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
	defer func() {
		clearErr := header.layerBuffer.Clear()
		if clearErr != nil {
			log.Errorf("failed to clear layer buffer: %s", clearErr)
		}
	}()

	err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, gopacket.Payload(data))
	if err != nil {
		return fmt.Errorf("serialize layers: %w", err)
	}

	_, err = p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr})
	if err != nil {
		return fmt.Errorf("write to raw conn: %w", err)
	}

	return nil
}
// selectRawConn returns the eBPF proxy's raw socket that matches the address
// family of header: the IPv6 socket for IPv6 headers, the IPv4 socket otherwise.
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
	if !header.isIPv4 {
		return p.wgeBPFProxy.rawConnIPv6
	}
	return p.wgeBPFProxy.rawConnIPv4
}

View File

@@ -54,14 +54,6 @@ func (w *KernelFactory) GetProxy() Proxy {
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
// GetProxyPort reports the port of the eBPF proxy owned by this factory.
// It returns 0 when the factory has no eBPF proxy.
func (w *KernelFactory) GetProxyPort() uint16 {
	if proxy := w.ebpfProxy; proxy != nil {
		return proxy.GetProxyPort()
	}
	return 0
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil

View File

@@ -24,11 +24,6 @@ func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
// GetProxyPort always returns 0: userspace WireGuard doesn't use a separate
// proxy port, so there is nothing to report.
func (w *USPFactory) GetProxyPort() uint16 {
	return 0
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -8,87 +8,43 @@ import (
"os"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/client/net"
)
// PrepareSenderRawSocketIPv4 returns a raw socket for sending IPv4 packets.
// It delegates to prepareSenderRawSocket with the AF_INET family.
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
	const isIPv4 = true
	return prepareSenderRawSocket(syscall.AF_INET, isIPv4)
}
// PrepareSenderRawSocketIPv6 returns a raw socket for sending IPv6 packets.
// It delegates to prepareSenderRawSocket with the AF_INET6 family.
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
	const isIPv4 = false
	return prepareSenderRawSocket(syscall.AF_INET6, isIPv4)
}
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
if isIPv4 {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
} else {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file: %v", closeErr)
}
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
// Close the original file to release the FD (net.FilePacketConn duplicates it)
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
}
return packetConn, nil
}

View File

@@ -1,353 +0,0 @@
//go:build linux && !android
package wgproxy
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
func compareUDPAddr(addr1, addr2 net.Addr) bool {
udpAddr1, ok1 := addr1.(*net.UDPAddr)
udpAddr2, ok2 := addr2.(*net.UDPAddr)
if !ok1 || !ok2 {
return addr1.String() == addr2.String()
}
// Compare IP and Port, ignoring zone
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 runs the shared RedirectAs scenario against the
// eBPF proxy, redirecting to an IPv4 P2P endpoint.
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
	const wgPort = 51850

	sharedProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
	err := sharedProxy.Listen()
	if err != nil {
		t.Fatalf("failed to initialize ebpf proxy: %v", err)
	}
	defer func() {
		if freeErr := sharedProxy.Free(); freeErr != nil {
			t.Errorf("failed to free ebpf proxy: %v", freeErr)
		}
	}()

	wrapper := ebpf.NewProxyWrapper(sharedProxy)

	// NetBird UDP address of the remote peer.
	peerAddr := &net.UDPAddr{IP: net.ParseIP("100.108.111.177"), Port: 38746}
	// Endpoint the traffic should appear to come from after the redirect.
	directEndpoint := &net.UDPAddr{IP: net.ParseIP("192.168.0.56"), Port: 51820}

	testRedirectAs(t, wrapper, wgPort, peerAddr, directEndpoint)
}
// TestRedirectAs_eBPF_IPv6 runs the shared RedirectAs scenario against the
// eBPF proxy, redirecting to an IPv6 P2P endpoint.
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
	const wgPort = 51851

	sharedProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
	err := sharedProxy.Listen()
	if err != nil {
		t.Fatalf("failed to initialize ebpf proxy: %v", err)
	}
	defer func() {
		if freeErr := sharedProxy.Free(); freeErr != nil {
			t.Errorf("failed to free ebpf proxy: %v", freeErr)
		}
	}()

	wrapper := ebpf.NewProxyWrapper(sharedProxy)

	// NetBird UDP address of the remote peer.
	peerAddr := &net.UDPAddr{IP: net.ParseIP("100.108.111.177"), Port: 38746}
	// Endpoint the traffic should appear to come from after the redirect.
	directEndpoint := &net.UDPAddr{IP: net.ParseIP("fe80::56"), Port: 51820}

	testRedirectAs(t, wrapper, wgPort, peerAddr, directEndpoint)
}
// TestRedirectAs_UDP_IPv4 runs the shared RedirectAs scenario against the
// UDP proxy, redirecting to an IPv4 P2P endpoint.
func TestRedirectAs_UDP_IPv4(t *testing.T) {
	const wgPort = 51852

	udpProxy := udp.NewWGUDPProxy(wgPort, 1280)

	// NetBird UDP address of the remote peer.
	peerAddr := &net.UDPAddr{IP: net.ParseIP("100.108.111.177"), Port: 38746}
	// Endpoint the traffic should appear to come from after the redirect.
	directEndpoint := &net.UDPAddr{IP: net.ParseIP("192.168.0.56"), Port: 51820}

	testRedirectAs(t, udpProxy, wgPort, peerAddr, directEndpoint)
}
// TestRedirectAs_UDP_IPv6 runs the shared RedirectAs scenario against the
// UDP proxy, redirecting to an IPv6 P2P endpoint.
func TestRedirectAs_UDP_IPv6(t *testing.T) {
	const wgPort = 51853

	udpProxy := udp.NewWGUDPProxy(wgPort, 1280)

	// NetBird UDP address of the remote peer.
	peerAddr := &net.UDPAddr{IP: net.ParseIP("100.108.111.177"), Port: 38746}
	// Endpoint the traffic should appear to come from after the redirect.
	directEndpoint := &net.UDPAddr{IP: net.ParseIP("fe80::56"), Port: 51820}

	testRedirectAs(t, udpProxy, wgPort, peerAddr, directEndpoint)
}
// testRedirectAs is a helper function that tests the RedirectAs functionality.
// It verifies that:
// 1. Initial traffic from relay connection works
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
// 3. Multiple packets are correctly redirected with the new source address
//
// Parameters:
//   - proxy: the proxy implementation under test
//   - wgPort: local port the fake WireGuard listeners bind to on 127.0.0.1 and ::1
//   - nbAddr: address passed to proxy.AddTurnConn as the peer endpoint
//   - p2pEndpoint: address passed to RedirectAs; redirected packets must carry it as source
//
// Any setup, read or write failure aborts the test with t.Fatalf; content and
// source-address mismatches are reported with t.Errorf.
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
	t.Helper()
	ctx := context.Background()
	// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
	// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
	wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
		IP: net.ParseIP("127.0.0.1"),
		Port: wgPort,
	})
	if err != nil {
		t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
	}
	defer wgListener4.Close()
	wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
		IP: net.ParseIP("::1"),
		Port: wgPort,
	})
	if err != nil {
		t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
	}
	defer wgListener6.Close()
	// Pick the listener that matches the IP version of the P2P endpoint.
	// Redirected traffic (after RedirectAs) is read from this listener; the
	// initial relay traffic in phase 1 is always read from the IPv4 listener.
	var wgListener *net.UDPConn
	if p2pEndpoint.IP.To4() == nil {
		wgListener = wgListener6
	} else {
		wgListener = wgListener4
	}
	// Create relay server and connection
	relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
		IP: net.ParseIP("127.0.0.1"),
		Port: 0, // Random port
	})
	if err != nil {
		t.Fatalf("failed to create relay server: %v", err)
	}
	defer relayServer.Close()
	relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
	if err != nil {
		t.Fatalf("failed to create relay connection: %v", err)
	}
	defer relayConn.Close()
	// Add TURN connection to proxy
	if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
		t.Fatalf("failed to add TURN connection: %v", err)
	}
	defer func() {
		if err := proxy.CloseConn(); err != nil {
			t.Errorf("failed to close proxy connection: %v", err)
		}
	}()
	// Start the proxy
	proxy.Work()
	// Phase 1: Test initial relay traffic
	// The relay server writes to the dialed connection's local address, so the
	// message shows up as incoming data on relayConn, which the proxy reads.
	msgFromRelay := []byte("hello from relay")
	if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
		t.Fatalf("failed to write to relay server: %v", err)
	}
	// Set read deadline to avoid hanging
	if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("failed to set read deadline: %v", err)
	}
	buf := make([]byte, 1024)
	n, _, err := wgListener4.ReadFrom(buf)
	if err != nil {
		t.Fatalf("failed to read from WireGuard listener: %v", err)
	}
	if n != len(msgFromRelay) {
		t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
	}
	if string(buf[:n]) != string(msgFromRelay) {
		t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
	}
	// Phase 2: Redirect to p2p endpoint
	proxy.RedirectAs(p2pEndpoint)
	// Give the proxy a moment to process the redirect
	time.Sleep(100 * time.Millisecond)
	// Phase 3: Test redirected traffic
	redirectedMessages := [][]byte{
		[]byte("redirected message 1"),
		[]byte("redirected message 2"),
		[]byte("redirected message 3"),
	}
	for i, msg := range redirectedMessages {
		if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
			t.Fatalf("failed to write redirected message %d: %v", i+1, err)
		}
		// Refresh the deadline for every message so each read gets its own timeout.
		if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
			t.Fatalf("failed to set read deadline: %v", err)
		}
		n, srcAddr, err := wgListener.ReadFrom(buf)
		if err != nil {
			t.Fatalf("failed to read redirected message %d: %v", i+1, err)
		}
		// Verify message content
		if string(buf[:n]) != string(msg) {
			t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
		}
		// Verify source address matches p2p endpoint (this is the key test)
		// Use compareUDPAddr to ignore IPv6 zone IDs
		if !compareUDPAddr(srcAddr, p2pEndpoint) {
			t.Errorf("message %d: expected source address %s, got %s",
				i+1, p2pEndpoint.String(), srcAddr.String())
		}
	}
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints.
// Using the eBPF proxy, it calls RedirectAs for three IPv4 endpoints in turn
// and, after each switch, checks that one relayed message reaches the fake
// WireGuard listener unchanged and with that endpoint as its source address.
func TestRedirectAs_Multiple_Switches(t *testing.T) {
	wgPort := 51856
	ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
	if err := ebpfProxy.Listen(); err != nil {
		t.Fatalf("failed to initialize ebpf proxy: %v", err)
	}
	defer func() {
		if err := ebpfProxy.Free(); err != nil {
			t.Errorf("failed to free ebpf proxy: %v", err)
		}
	}()
	proxy := ebpf.NewProxyWrapper(ebpfProxy)
	ctx := context.Background()
	// Create WireGuard listener
	wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
		IP: net.ParseIP("127.0.0.1"),
		Port: wgPort,
	})
	if err != nil {
		t.Fatalf("failed to create WireGuard listener: %v", err)
	}
	defer wgListener.Close()
	// Create relay server and connection
	relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
		IP: net.ParseIP("127.0.0.1"),
		Port: 0,
	})
	if err != nil {
		t.Fatalf("failed to create relay server: %v", err)
	}
	defer relayServer.Close()
	relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
	if err != nil {
		t.Fatalf("failed to create relay connection: %v", err)
	}
	defer relayConn.Close()
	// Address handed to AddTurnConn as the peer endpoint
	nbAddr := &net.UDPAddr{
		IP: net.ParseIP("100.108.111.177"),
		Port: 38746,
	}
	if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
		t.Fatalf("failed to add TURN connection: %v", err)
	}
	defer func() {
		if err := proxy.CloseConn(); err != nil {
			t.Errorf("failed to close proxy connection: %v", err)
		}
	}()
	proxy.Work()
	// Test switching between multiple endpoints - using addresses in local subnet
	endpoints := []*net.UDPAddr{
		{IP: net.ParseIP("192.168.0.100"), Port: 51820},
		{IP: net.ParseIP("192.168.0.101"), Port: 51821},
		{IP: net.ParseIP("192.168.0.102"), Port: 51822},
	}
	for i, endpoint := range endpoints {
		proxy.RedirectAs(endpoint)
		// Give the proxy a moment to pick up the new endpoint before sending
		time.Sleep(100 * time.Millisecond)
		msg := []byte("test message")
		if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
			t.Fatalf("failed to write message for endpoint %d: %v", i, err)
		}
		buf := make([]byte, 1024)
		// Bound the read so a lost packet fails the test instead of hanging it
		if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
			t.Fatalf("failed to set read deadline: %v", err)
		}
		n, srcAddr, err := wgListener.ReadFrom(buf)
		if err != nil {
			t.Fatalf("failed to read message for endpoint %d: %v", i, err)
		}
		if string(buf[:n]) != string(msg) {
			t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
		}
		// The packet must appear to originate from the endpoint just redirected to
		if !compareUDPAddr(srcAddr, endpoint) {
			t.Errorf("endpoint %d: expected source %s, got %s",
				i, endpoint.String(), srcAddr.String())
		}
	}
}

View File

@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {

View File

@@ -19,56 +19,37 @@ var (
FixLengths: true,
}
localHostNetIPAddrV4 = &net.IPAddr{
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
// Create only the raw socket for the address family we need
var rawSocket net.PacketConn
var err error
var localHostAddr *net.IPAddr
if srcAddr.IP.To4() != nil {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
localHostAddr = localHostNetIPAddrV4
} else {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
localHostAddr = localHostNetIPAddrV6
}
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
if closeErr := rawSocket.Close(); closeErr != nil {
log.Warnf("failed to close raw socket: %v", closeErr)
}
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
@@ -91,7 +72,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
@@ -99,40 +80,19 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
// Check if source IP is IPv4 or IPv6
if srcAddr.IP.To4() != nil {
// IPv4
ipv4 := &layers.IPv4{
DstIP: localHostNetIPAddrV4.IP,
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: localHostNetIPAddrV6.IP,
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(networkLayer)
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}

View File

@@ -189,212 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval checks that rule pairs tracked by the ACL
// manager shrink as rules disappear from successive network maps: first a
// deny and an accept rule, then only the accept rule, then no rules at all.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
	t.Setenv("NB_WG_KERNEL_DISABLED", "true")

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	wgPrefix := netip.MustParsePrefix("172.0.0.1/32")

	wgIface := mocks.NewMockIFaceMapper(mockCtrl)
	wgIface.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
	wgIface.EXPECT().SetFilter(gomock.Any())
	wgIface.EXPECT().Name().Return("lo").AnyTimes()
	wgIface.EXPECT().Address().Return(wgaddr.Address{
		IP:      wgPrefix.Addr(),
		Network: wgPrefix,
	}).AnyTimes()
	wgIface.EXPECT().GetWGDevice().Return(nil).AnyTimes()

	fwManager, err := firewall.NewFirewall(wgIface, nil, flowLogger, false, iface.DefaultMTU)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, fwManager.Close(nil))
	}()

	manager := NewDefaultManager(fwManager)

	// Successive network map updates and the number of rule pairs expected
	// after each one has been applied.
	updates := []struct {
		networkMap *mgmProto.NetworkMap
		wantPairs  int
		failMsg    string
	}{
		{
			// First update: add deny and accept rules.
			networkMap: &mgmProto.NetworkMap{
				FirewallRules: []*mgmProto.FirewallRule{
					{
						PeerIP:    "10.93.0.1",
						Direction: mgmProto.RuleDirection_IN,
						Action:    mgmProto.RuleAction_DROP,
						Protocol:  mgmProto.RuleProtocol_TCP,
						Port:      "22",
					},
					{
						PeerIP:    "10.93.0.2",
						Direction: mgmProto.RuleDirection_IN,
						Action:    mgmProto.RuleAction_ACCEPT,
						Protocol:  mgmProto.RuleProtocol_TCP,
						Port:      "443",
					},
				},
				FirewallRulesIsEmpty: false,
			},
			wantPairs: 2,
			failMsg:   "Should have 2 rules after first update",
		},
		{
			// Second update: remove the deny rule, keep only accept.
			networkMap: &mgmProto.NetworkMap{
				FirewallRules: []*mgmProto.FirewallRule{
					{
						PeerIP:    "10.93.0.2",
						Direction: mgmProto.RuleDirection_IN,
						Action:    mgmProto.RuleAction_ACCEPT,
						Protocol:  mgmProto.RuleProtocol_TCP,
						Port:      "443",
					},
				},
				FirewallRulesIsEmpty: false,
			},
			wantPairs: 1,
			failMsg:   "Should have 1 rule after removing deny rule",
		},
		{
			// Third update: remove all rules.
			networkMap: &mgmProto.NetworkMap{
				FirewallRules:        []*mgmProto.FirewallRule{},
				FirewallRulesIsEmpty: true,
			},
			wantPairs: 0,
			failMsg:   "Should have 0 rules after removing all rules",
		},
	}

	for _, step := range updates {
		manager.ApplyFiltering(step.networkMap, false)
		assert.Equal(t, step.wantPairs, len(manager.peerRulesPairs), step.failMsg)
	}
}
// TestRuleUpdateChangingAction checks that flipping a rule's action from
// accept to deny (same peer IP, port and protocol) leaves the ACL manager
// tracking a single rule pair rather than keeping the old one alongside the
// new one.
func TestRuleUpdateChangingAction(t *testing.T) {
	t.Setenv("NB_WG_KERNEL_DISABLED", "true")

	mockCtrl := gomock.NewController(t)
	defer mockCtrl.Finish()

	wgPrefix := netip.MustParsePrefix("172.0.0.1/32")

	wgIface := mocks.NewMockIFaceMapper(mockCtrl)
	wgIface.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
	wgIface.EXPECT().SetFilter(gomock.Any())
	wgIface.EXPECT().Name().Return("lo").AnyTimes()
	wgIface.EXPECT().Address().Return(wgaddr.Address{
		IP:      wgPrefix.Addr(),
		Network: wgPrefix,
	}).AnyTimes()
	wgIface.EXPECT().GetWGDevice().Return(nil).AnyTimes()

	fwManager, err := firewall.NewFirewall(wgIface, nil, flowLogger, false, iface.DefaultMTU)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, fwManager.Close(nil))
	}()

	manager := NewDefaultManager(fwManager)

	// The two variants differ only in Action.
	acceptSSH := &mgmProto.FirewallRule{
		PeerIP:    "10.93.0.1",
		Direction: mgmProto.RuleDirection_IN,
		Action:    mgmProto.RuleAction_ACCEPT,
		Protocol:  mgmProto.RuleProtocol_TCP,
		Port:      "22",
	}
	dropSSH := &mgmProto.FirewallRule{
		PeerIP:    "10.93.0.1",
		Direction: mgmProto.RuleDirection_IN,
		Action:    mgmProto.RuleAction_DROP,
		Protocol:  mgmProto.RuleProtocol_TCP,
		Port:      "22",
	}

	// First update: accept rule.
	update := &mgmProto.NetworkMap{
		FirewallRules:        []*mgmProto.FirewallRule{acceptSSH},
		FirewallRulesIsEmpty: false,
	}
	manager.ApplyFiltering(update, false)
	assert.Equal(t, 1, len(manager.peerRulesPairs))

	// Second update: the same network map object, now carrying the deny variant.
	update.FirewallRules = []*mgmProto.FirewallRule{dropSSH}
	manager.ApplyFiltering(update, false)

	// Should still have exactly 1 rule (the old accept removed, new deny added).
	assert.Equal(t, 1, len(manager.peerRulesPairs),
		"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -1,499 +0,0 @@
package auth
import (
"context"
"net/url"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
	// mutex is held by Close while it checks and closes client.
	mutex sync.RWMutex
	// client is the gRPC client connected to the management server; Close treats nil as already closed.
	client *mgm.GrpcClient
	// config is the profile configuration handed to NewAuth.
	config *profilemanager.Config
	// privateKey is the parsed WireGuard private key used when creating the management client.
	privateKey wgtypes.Key
	// mgmURL is the management server URL handed to NewAuth.
	mgmURL *url.URL
	// mgmTLSEnabled is true when mgmURL uses the "https" scheme.
	mgmTLSEnabled bool
}
// NewAuth builds an Auth bound to the management server at mgmURL.
//
// privateKey must be a valid WireGuard key; it is parsed first and a parse
// failure is returned as is. TLS is used when the URL scheme is "https". The
// management client is created immediately and kept on the returned Auth for
// later operations; if the client cannot be created, the error is logged and
// returned.
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
	wgKey, err := wgtypes.ParseKey(privateKey)
	if err != nil {
		return nil, err
	}

	// The URL scheme alone decides whether the connection uses TLS.
	useTLS := mgmURL.Scheme == "https"
	target := mgmURL.String()

	log.Debugf("connecting to Management Service %s", target)
	grpcClient, err := mgm.NewClient(ctx, mgmURL.Host, wgKey, useTLS)
	if err != nil {
		log.Errorf("failed connecting to Management Service %s: %v", target, err)
		return nil, err
	}
	log.Debugf("connected to the Management service %s", target)

	auth := &Auth{
		client:        grpcClient,
		config:        config,
		privateKey:    wgKey,
		mgmURL:        mgmURL,
		mgmTLSEnabled: useTLS,
	}
	return auth, nil
}
// Close shuts down the current management client connection, if there is one.
// It takes the write lock, so it does not run concurrently with reconnect.
func (a *Auth) Close() error {
	a.mutex.Lock()
	defer a.mutex.Unlock()

	if c := a.client; c != nil {
		return c.Close()
	}
	return nil
}
// IsSSOSupported reports whether the management server offers an SSO login flow.
//
// The PKCE flow is probed first. Only when the server answers NotFound or Unimplemented
// is the Device flow probed as well. The result is true as soon as either flow can be
// retrieved, and false (with a nil error) when both are reported as NotFound/Unimplemented.
// Any other error is passed to withRetry, which may retry and reconnect; the error that
// finally ends the retry loop is returned to the caller.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
	// flowMissing reports whether err is a gRPC status saying the flow is not offered.
	flowMissing := func(err error) bool {
		st, ok := status.FromError(err)
		return ok && (st.Code() == codes.NotFound || st.Code() == codes.Unimplemented)
	}

	supportsSSO := false
	err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
		_, pkceErr := a.getPKCEFlow(client)
		if pkceErr == nil {
			supportsSSO = true
			return nil
		}
		if !flowMissing(pkceErr) {
			// PKCE probe failed for a reason other than "not offered"
			return pkceErr
		}

		_, deviceErr := a.getDeviceFlow(client)
		switch {
		case deviceErr == nil:
			supportsSSO = true
			return nil
		case flowMissing(deviceErr):
			// Neither flow is offered by this server
			supportsSSO = false
			return nil
		default:
			return deviceErr
		}
	})

	return supportsSSO, err
}
// GetOAuthFlow returns an OAuth flow (PKCE or Device) obtained over the existing
// management connection instead of opening a new one.
//
// When forceDeviceAuth is true only the Device flow is requested. Otherwise the PKCE flow
// is tried first and the Device flow is used as a fallback when the server answers PKCE
// with NotFound or Unimplemented. Calls go through withRetry, so connection errors are
// retried with backoff and reconnection.
//
// On failure the returned flow is a true nil interface. The concrete flow pointers are
// only stored into the OAuthFlow interface after a successful call, so a nil
// *PKCEAuthorizationFlow or *DeviceAuthorizationFlow can never leak out as a non-nil
// interface value.
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
	var flow OAuthFlow

	err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
		if !forceDeviceAuth {
			pkceFlow, err := a.getPKCEFlow(client)
			if err == nil {
				flow = pkceFlow
				return nil
			}
			// Fall through to the Device flow only when PKCE is not offered by the server
			s, ok := status.FromError(err)
			if !ok || (s.Code() != codes.NotFound && s.Code() != codes.Unimplemented) {
				return err
			}
		}

		deviceFlow, err := a.getDeviceFlow(client)
		if err != nil {
			return err
		}
		flow = deviceFlow
		return nil
	})
	if err != nil {
		return nil, err
	}

	return flow, nil
}
// IsLoginRequired probes the management server with a login attempt and reports whether
// the peer has to authenticate.
//
// The answer is true when the login attempt fails with an error that isLoginNeeded
// classifies as "login needed"; in that case the error is consumed and nil is returned.
// Any other error goes through withRetry (backoff plus reconnection on connection
// errors) and, if it persists, is returned together with false.
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
	pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
	if err != nil {
		return false, err
	}

	var needsLogin bool
	probe := func(client *mgm.GrpcClient) error {
		_, _, loginErr := a.doMgmLogin(client, ctx, pubSSHKey)
		needsLogin = isLoginNeeded(loginErr)
		if needsLogin {
			return nil
		}
		return loginErr
	}

	err = a.withRetry(ctx, probe)
	return needsLogin, err
}
// Login logs the peer in to the management server, registering it first when needed.
//
// A login attempt is made first. If the server key was obtained and the attempt failed
// with an error that isRegistrationNeeded accepts, the peer is registered using setupKey
// and/or jwtToken. The whole sequence runs inside withRetry, so connection errors are
// retried with backoff and reconnection.
//
// The boolean result is true when the last error seen was PermissionDenied, i.e. the
// server answered but refused access, so callers should stop retrying.
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
	pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
	if err != nil {
		return err, false
	}

	var isAuthError bool
	err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
		serverKey, _, opErr := a.doMgmLogin(client, ctx, pubSSHKey)
		if serverKey != nil && isRegistrationNeeded(opErr) {
			log.Debugf("peer registration required")
			_, opErr = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
		}

		// opErr is now the outcome of whichever call ran last; nil clears the flag.
		isAuthError = isPermissionDenied(opErr)
		return opErr
	})

	return err, isAuthError
}
// getPKCEFlow asks the management server (through client) for the PKCE authorization
// flow configuration and builds a PKCEAuthorizationFlow from it.
//
// Errors from the management calls are returned unchanged so callers can inspect the
// gRPC status code; NotFound is logged as a warning, anything else as an error. The
// client certificate pair comes from the local profile config, not from the server.
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
	serverKey, err := client.GetServerPublicKey()
	if err != nil {
		log.Errorf("failed while getting Management Service public key: %v", err)
		return nil, err
	}

	protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
	if err != nil {
		if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
			log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
		} else {
			log.Errorf("failed to retrieve pkce flow: %v", err)
		}
		return nil, err
	}

	remote := protoFlow.GetProviderConfig()
	providerCfg := PKCEAuthProviderConfig{
		Audience:              remote.GetAudience(),
		ClientID:              remote.GetClientID(),
		ClientSecret:          remote.GetClientSecret(),
		TokenEndpoint:         remote.GetTokenEndpoint(),
		AuthorizationEndpoint: remote.GetAuthorizationEndpoint(),
		Scope:                 remote.GetScope(),
		RedirectURLs:          remote.GetRedirectURLs(),
		UseIDToken:            remote.GetUseIDToken(),
		ClientCertPair:        a.config.ClientCertKeyPair,
		DisablePromptLogin:    remote.GetDisablePromptLogin(),
		LoginFlag:             common.LoginFlag(remote.GetLoginFlag()),
	}

	if err := validatePKCEConfig(&providerCfg); err != nil {
		return nil, err
	}

	pkceFlow, err := NewPKCEAuthorizationFlow(providerCfg)
	if err != nil {
		return nil, err
	}
	return pkceFlow, nil
}
// getDeviceFlow asks the management server (through client) for the device authorization
// flow configuration and builds a DeviceAuthorizationFlow from it.
//
// Errors from the management calls are returned unchanged so callers can inspect the
// gRPC status code; NotFound is logged as a warning, anything else as an error. An empty
// scope is replaced by "openid" before validation, and a configuration rejected by
// validateDeviceAuthConfig is returned as an error without building a flow.
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
	serverKey, err := client.GetServerPublicKey()
	if err != nil {
		log.Errorf("failed while getting Management Service public key: %v", err)
		return nil, err
	}

	protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
	if err != nil {
		if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
			log.Warnf("server couldn't find device flow, contact admin: %v", err)
			return nil, err
		}
		log.Errorf("failed to retrieve device flow: %v", err)
		return nil, err
	}

	protoConfig := protoFlow.GetProviderConfig()
	config := &DeviceAuthProviderConfig{
		Audience:     protoConfig.GetAudience(),
		ClientID:     protoConfig.GetClientID(),
		ClientSecret: protoConfig.GetClientSecret(),
		// Read Domain through its getter like every other field: generated getters are
		// nil-safe, so a response without a provider config reaches validation below
		// instead of panicking on a nil pointer dereference.
		Domain:             protoConfig.GetDomain(),
		TokenEndpoint:      protoConfig.GetTokenEndpoint(),
		DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
		Scope:              protoConfig.GetScope(),
		UseIDToken:         protoConfig.GetUseIDToken(),
	}

	// Keep compatibility with older management versions
	if config.Scope == "" {
		config.Scope = "openid"
	}

	if err := validateDeviceAuthConfig(config); err != nil {
		return nil, err
	}

	flow, err := NewDeviceAuthorizationFlow(*config)
	if err != nil {
		return nil, err
	}

	return flow, nil
}
// doMgmLogin performs one login call against the management service using client.
//
// It fetches the server public key, collects local system info with the configured
// feature flags applied, and sends the login request. The server key is returned even
// when the login call itself fails; it is nil only when the key could not be fetched.
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
	serverKey, keyErr := client.GetServerPublicKey()
	if keyErr != nil {
		log.Errorf("failed while getting Management Service public key: %v", keyErr)
		return nil, nil, keyErr
	}

	info := system.GetInfo(ctx)
	a.setSystemInfoFlags(info)

	resp, loginErr := client.Login(*serverKey, info, pubSSHKey, a.config.DNSLabels)
	return serverKey, resp, loginErr
}
// registerPeer registers this peer with the management service using setupKey and/or
// jwtToken.
//
// The setup key must parse as a UUID unless a JWT token is supplied; when it does not
// parse and no token is given, an InvalidArgument status error is returned without
// contacting the register endpoint. The parsed key's string form is what gets sent.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
	serverPublicKey, err := client.GetServerPublicKey()
	if err != nil {
		log.Errorf("failed while getting Management Service public key: %v", err)
		return nil, err
	}

	parsedKey, parseErr := uuid.Parse(setupKey)
	if jwtToken == "" && parseErr != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", parseErr)
	}

	log.Debugf("sending peer registration request to Management Service")

	sysInfo := system.GetInfo(ctx)
	a.setSystemInfoFlags(sysInfo)

	resp, err := client.Register(*serverPublicKey, parsedKey.String(), jwtToken, sysInfo, pubSSHKey, a.config.DNSLabels)
	if err != nil {
		log.Errorf("failed registering peer %v", err)
		return nil, err
	}

	log.Infof("peer has been successfully registered on Management Service")
	return resp, nil
}
// setSystemInfoFlags copies the feature flags from the profile config onto info.
// The argument order below is positional and must match system.Info.SetFlags.
func (a *Auth) setSystemInfoFlags(info *system.Info) {
	cfg := a.config

	info.SetFlags(
		cfg.RosenpassEnabled,
		cfg.RosenpassPermissive,
		cfg.ServerSSHAllowed,
		cfg.DisableClientRoutes,
		cfg.DisableServerRoutes,
		cfg.DisableDNS,
		cfg.DisableFirewall,
		cfg.BlockLANAccess,
		cfg.BlockInbound,
		cfg.LazyConnectionEnabled,
		cfg.EnableSSHRoot,
		cfg.EnableSSHSFTP,
		cfg.EnableSSHLocalPortForwarding,
		cfg.EnableSSHRemotePortForwarding,
		cfg.DisableSSHAuth,
	)
}
// reconnect replaces the management client with a freshly created one.
//
// brokenClient is the client the caller observed failing. If it is no longer the current
// client, another goroutine has already reconnected and nothing is done. The new
// connection is established before the old one is closed, so a failed attempt leaves the
// previous client in place and a.client is never set to nil here. The write lock is held
// for the whole call, including the dial.
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
	a.mutex.Lock()
	defer a.mutex.Unlock()

	if brokenClient != a.client {
		log.Debugf("client already reconnected by another thread, skipping")
		return nil
	}

	log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
	freshClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
	if err != nil {
		// The old client stays installed when the new connection cannot be created
		log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
		return err
	}

	// Swap first, then close the previous connection; a close failure is only logged
	staleClient := a.client
	a.client = freshClient
	if staleClient != nil {
		if closeErr := staleClient.Close(); closeErr != nil {
			log.Debugf("error closing old connection: %v", closeErr)
		}
	}

	log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
	return nil
}
// isConnectionError reports whether err is a gRPC status whose code is treated as a
// connection problem (Unavailable, DeadlineExceeded, Canceled or Internal) and should
// therefore trigger a reconnection. nil and non-status errors yield false.
func isConnectionError(err error) bool {
	if err == nil {
		return false
	}

	st, ok := status.FromError(err)
	if !ok {
		return false
	}

	switch st.Code() {
	case codes.Unavailable, codes.DeadlineExceeded, codes.Canceled, codes.Internal:
		return true
	default:
		return false
	}
}
// withRetry runs operation with exponential backoff (500ms initial interval, factor 1.5,
// capped at 10s per wait and 2 minutes overall) bound to ctx.
//
// Each attempt snapshots the current client under the read lock and passes it to
// operation. Outcomes:
//   - nil: done.
//   - connection error: reconnect is attempted against the snapshotted client; the
//     reconnect error (if any) or the original error is returned so the attempt is retried.
//   - authentication error: wrapped in backoff.Permanent, which stops retrying.
//   - anything else: returned as-is and retried.
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
	policy := &backoff.ExponentialBackOff{
		InitialInterval:     500 * time.Millisecond,
		RandomizationFactor: 0.5,
		Multiplier:          1.5,
		MaxInterval:         10 * time.Second,
		MaxElapsedTime:      2 * time.Minute,
		Stop:                backoff.Stop,
		Clock:               backoff.SystemClock,
	}
	policy.Reset()

	attempt := func() error {
		// Snapshot the client before running so reconnect is told which client failed
		a.mutex.RLock()
		currentClient := a.client
		a.mutex.RUnlock()

		if currentClient == nil {
			return status.Errorf(codes.Unavailable, "client is not initialized")
		}

		err := operation(currentClient)
		switch {
		case err == nil:
			return nil

		case isConnectionError(err):
			log.Warnf("connection error detected, attempting reconnection: %v", err)
			if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
				log.Errorf("reconnection failed: %v", reconnectErr)
				return reconnectErr
			}
			// Surface the original error so the next attempt runs on the new connection
			return err

		case isAuthenticationError(err):
			// Retrying cannot fix an authentication failure
			return backoff.Permanent(err)

		default:
			return err
		}
	}

	onRetry := func(err error, wait time.Duration) {
		log.Warnf("operation failed, retrying in %v: %v", wait, err)
	}

	return backoff.RetryNotify(attempt, backoff.WithContext(policy, ctx), onRetry)
}
// isAuthenticationError reports whether err is a gRPC status with code InvalidArgument
// or PermissionDenied. Such errors are not retried because repeating the request cannot
// change the outcome. nil and non-status errors yield false.
func isAuthenticationError(err error) bool {
	if err == nil {
		return false
	}

	if st, ok := status.FromError(err); ok {
		code := st.Code()
		return code == codes.InvalidArgument || code == codes.PermissionDenied
	}
	return false
}
// isPermissionDenied reports whether err is a gRPC status with code PermissionDenied,
// meaning the server responded but refused access. Callers use it to decide on an early
// exit from backoff. nil and non-status errors yield false.
func isPermissionDenied(err error) bool {
	if err == nil {
		return false
	}

	st, ok := status.FromError(err)
	return ok && st.Code() == codes.PermissionDenied
}
// isLoginNeeded reports whether a login attempt's error means the peer must log in.
// It is an alias for isAuthenticationError (InvalidArgument or PermissionDenied).
func isLoginNeeded(err error) bool {
	return isAuthenticationError(err)
}
// isRegistrationNeeded reports whether a login attempt's error means the peer has to be
// registered first. It is an alias for isPermissionDenied.
func isRegistrationNeeded(err error) bool {
	return isPermissionDenied(err)
}

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -25,56 +26,12 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
	// ClientID is the IDP application client id
	ClientID string
	// ClientSecret is the IDP application client secret
	ClientSecret string
	// Domain is the IDP API domain
	// Deprecated. Use OIDCConfigEndpoint instead
	Domain string
	// Audience is the audience used for authorization validation
	Audience string
	// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
	TokenEndpoint string
	// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
	DeviceAuthEndpoint string
	// Scope provides the scopes to be included in the token request
	Scope string
	// UseIDToken indicates if the id token should be used for authentication
	UseIDToken bool
	// LoginHint is used to pre-fill the email/username field during authentication
	LoginHint string
}
// validateDeviceAuthConfig checks that every field the device authorization flow needs
// is present in config. Fields are checked in a fixed order and the first empty one is
// reported in the returned error; nil means the configuration is complete.
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
	required := []struct {
		label string
		value string
	}{
		{"Audience", config.Audience},
		{"Client ID", config.ClientID},
		{"Token Endpoint", config.TokenEndpoint},
		{"Device Auth Endpoint", config.DeviceAuthEndpoint},
		{"Device Auth Scopes", config.Scope},
	}

	for _, field := range required {
		if field.value == "" {
			return fmt.Errorf("invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator", field.label)
		}
	}
	return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -100,7 +57,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -132,11 +89,6 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -247,22 +199,14 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,6 +12,8 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -113,19 +115,18 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -279,19 +280,18 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,6 +10,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -86,33 +87,19 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
return pkceFlowInfo, nil
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -127,9 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
return deviceFlowInfo, nil
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -20,6 +20,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -34,67 +35,17 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
	// ClientID is the IDP application client id
	ClientID string
	// ClientSecret is the IDP application client secret
	ClientSecret string
	// Audience is the audience used for authorization validation
	Audience string
	// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
	TokenEndpoint string
	// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
	AuthorizationEndpoint string
	// Scope provides the scopes to be included in the token request
	Scope string
	// RedirectURLs lists the redirect URLs that receive the authorization code from the IDP manager
	RedirectURLs []string
	// UseIDToken indicates if the id token should be used for authentication
	UseIDToken bool
	// ClientCertPair is used for mTLS authentication to the IDP
	ClientCertPair *tls.Certificate
	// DisablePromptLogin makes the PKCE flow not prompt the user for login
	DisablePromptLogin bool
	// LoginFlag is used to configure the PKCE flow login behavior
	LoginFlag common.LoginFlag
	// LoginHint is used to pre-fill the email/username field during authentication
	LoginHint string
}
// validatePKCEConfig checks that every field the PKCE authorization flow needs is
// present in config. Fields are checked in a fixed order and the first missing one is
// reported in the returned error; nil means the configuration is complete.
//
// RedirectURLs is rejected when it has no entries, whether the slice is nil or merely
// empty, because either way there is no redirect URL to use.
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
	errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
	if config.ClientID == "" {
		return fmt.Errorf(errorMsgFormat, "Client ID")
	}
	if config.TokenEndpoint == "" {
		return fmt.Errorf(errorMsgFormat, "Token Endpoint")
	}
	if config.AuthorizationEndpoint == "" {
		return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
	}
	if config.Scope == "" {
		return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
	}
	// len covers both a nil slice and an allocated-but-empty one
	if len(config.RedirectURLs) == 0 {
		return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
	}
	return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig PKCEAuthProviderConfig
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
excludedRanges := getSystemExcludedPortRanges()
@@ -173,21 +124,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -198,7 +138,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -209,8 +149,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
go p.startServer(server, tokenChan, errChan)
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -49,7 +50,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -9,6 +9,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
@@ -93,7 +95,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
availablePort := 65432
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -60,6 +59,7 @@ func NewConnectClient(
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
@@ -71,8 +71,8 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan)
}
// RunOnAndroid with main logic on mobile system
@@ -93,7 +93,7 @@ func (c *ConnectClient) RunOnAndroid(
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
return c.run(mobileDependency, nil)
}
func (c *ConnectClient) RunOniOS(
@@ -111,10 +111,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
return c.run(mobileDependency, nil)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -245,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
c.statusRecorder.UpdateLocalPeerState(localPeerState)
@@ -284,7 +284,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
relayURLs, token := parseRelayInfo(loginResp)
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -472,7 +472,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -507,10 +507,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
ProfileConfig: config,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
}
if config.PreSharedKey != "" {

View File

@@ -28,10 +28,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -59,7 +57,6 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
@@ -226,11 +223,10 @@ type BundleGenerator struct {
internalConfig *profilemanager.Config
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
logFile string
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
@@ -239,6 +235,7 @@ type BundleGenerator struct {
type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
@@ -247,9 +244,7 @@ type GeneratorDependencies struct {
InternalConfig *profilemanager.Config
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
LogFile string
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -265,11 +260,10 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
internalConfig: deps.InternalConfig,
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
logFile: deps.LogFile,
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
@@ -315,6 +309,13 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add status: %w", err)
}
if g.statusRecorder != nil {
status := g.statusRecorder.GetFullStatus()
seedFromStatus(g.anonymizer, &status)
} else {
log.Debugf("no status recorder available for seeding")
}
if err := g.addConfig(); err != nil {
log.Errorf("failed to add config to debug bundle: %v", err)
}
@@ -331,10 +332,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addCPUProfile(); err != nil {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
@@ -355,7 +352,7 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
@@ -404,30 +401,11 @@ func (g *BundleGenerator) addReadme() error {
}
func (g *BundleGenerator) addStatus() error {
if g.statusRecorder != nil {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
if g.refreshStatus != nil {
g.refreshStatus()
}
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
if status := g.clientStatus; status != "" {
statusReader := strings.NewReader(status)
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
seedFromStatus(g.anonymizer, &fullStatus)
} else {
log.Debugf("no status recorder available for seeding")
}
return nil
}
@@ -557,19 +535,6 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
// addCPUProfile stores the CPU profile held by the generator in the bundle as
// "cpu.prof". When no profile data is present it does nothing and returns nil.
func (g *BundleGenerator) addCPUProfile() error {
	profile := g.cpuProfile
	if len(profile) == 0 {
		return nil
	}

	err := g.addFileToZip(bytes.NewReader(profile), "cpu.prof")
	if err != nil {
		return fmt.Errorf("add CPU profile to zip: %w", err)
	}
	return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
@@ -745,14 +710,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
}
func (g *BundleGenerator) addLogfile() error {
if g.logPath == "" {
if g.logFile == "" {
log.Debugf("skipping empty log file in debug bundle")
return nil
}
logDir := filepath.Dir(g.logPath)
logDir := filepath.Dir(g.logFile)
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}

View File

@@ -1,101 +0,0 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
// maxBundleUploadSize is the largest bundle, in bytes, that upload accepts (50 MiB).
const maxBundleUploadSize = 50 * 1024 * 1024
// UploadDebugBundle obtains an upload location from the service at url (keyed
// by a hash of managementURL), uploads the file at filePath to that location,
// and returns the key reported by the service.
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
	target, err := getUploadURL(ctx, url, managementURL)
	if err != nil {
		return "", err
	}

	if uploadErr := upload(ctx, filePath, target); uploadErr != nil {
		return "", uploadErr
	}

	return target.Key, nil
}
// upload sends the file at filePath to response.URL with an HTTP PUT request.
//
// The file size is checked against maxBundleUploadSize before the request is
// created. Any response status other than 200 OK is reported as an error that
// includes the response body. All underlying errors are wrapped with %w so
// callers can inspect them with errors.Is / errors.As.
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
	fileData, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("open file: %w", err)
	}
	defer fileData.Close()

	stat, err := fileData.Stat()
	if err != nil {
		return fmt.Errorf("stat file: %w", err)
	}

	if stat.Size() > maxBundleUploadSize {
		return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPut, response.URL, fileData)
	if err != nil {
		return fmt.Errorf("create PUT request: %w", err)
	}

	// the body is a file, so the size obtained from Stat is supplied explicitly
	req.ContentLength = stat.Size()
	req.Header.Set("Content-Type", "application/octet-stream")

	putResp, err := http.DefaultClient.Do(req)
	if err != nil {
		// wrap (not just format) so context cancellation and network errors stay detectable
		return fmt.Errorf("upload failed: %w", err)
	}
	defer putResp.Body.Close()

	if putResp.StatusCode != http.StatusOK {
		// best effort: the body is only used to enrich the error message
		body, _ := io.ReadAll(putResp.Body)
		return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
	}
	return nil
}
// getUploadURL requests an upload location from the service at url. The hash
// of managementURL is sent as the "id" query parameter and the client header
// from the types package is attached. A status other than 200 OK is returned
// as an error that includes the response body; otherwise the JSON body is
// decoded into a types.GetURLResponse.
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
	endpoint := url + "?id=" + getURLHash(managementURL)

	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("create GET request: %w", err)
	}
	req.Header.Set(types.ClientHeader, types.ClientHeaderValue)

	httpResp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("get presigned URL: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// read error deliberately ignored: the body only enriches the message
		errBody, _ := io.ReadAll(httpResp.Body)
		return nil, fmt.Errorf("get presigned URL status %d: %s", httpResp.StatusCode, string(errBody))
	}

	payload, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response body: %w", err)
	}

	result := new(types.GetURLResponse)
	if err := json.Unmarshal(payload, result); err != nil {
		return nil, fmt.Errorf("unmarshal response: %w", err)
	}
	return result, nil
}
// getURLHash returns the lowercase hexadecimal SHA-256 digest of url.
func getURLHash(url string) string {
	digest := sha256.Sum256([]byte(url))
	return fmt.Sprintf("%x", digest)
}

View File

@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey != [32]byte{} {
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}

View File

@@ -0,0 +1,136 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
// as returned by GetDeviceAuthorizationFlowInfo.
type DeviceAuthorizationFlow struct {
// Provider is the string form of the provider reported by the Management Service
Provider string
// ProviderConfig holds the provider settings used to run the flow
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID is the IDP application client id
ClientID string
// ClientSecret is the IDP application client secret
ClientSecret string
// Domain is the IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience is the audience used for authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain an access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain a device authorization code
DeviceAuthEndpoint string
// Scope provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo retrieves the device authorization flow
// configuration from the Management Service at mgmURL and returns it as a
// DeviceAuthorizationFlow.
//
// privateKey must parse as a WireGuard key; it is handed to the Management
// client. TLS is enabled when mgmURL uses the "https" scheme. If the received
// scope is empty it is replaced by "openid" to stay compatible with older
// management versions. The resulting provider configuration is validated with
// isDeviceAuthProviderConfigValid before being returned.
//
// Any failure (key parsing, connection, server key lookup, flow retrieval or
// validation) is returned together with a zero DeviceAuthorizationFlow.
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
	// validate our peer's Wireguard PRIVATE key
	myPrivateKey, err := wgtypes.ParseKey(privateKey)
	if err != nil {
		// the key is a secret: log only the parse error, never the key itself
		log.Errorf("failed parsing Wireguard key: %v", err)
		return DeviceAuthorizationFlow{}, err
	}

	mgmTLSEnabled := mgmURL.Scheme == "https"

	log.Debugf("connecting to Management Service %s", mgmURL.String())
	mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
	if err != nil {
		log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
		return DeviceAuthorizationFlow{}, err
	}
	log.Debugf("connected to the Management service %s", mgmURL.String())

	defer func() {
		// use a dedicated variable so the close result never clobbers the outer err
		if closeErr := mgmClient.Close(); closeErr != nil {
			log.Warnf("failed to close the Management service client %v", closeErr)
		}
	}()

	serverKey, err := mgmClient.GetServerPublicKey()
	if err != nil {
		log.Errorf("failed while getting Management Service public key: %v", err)
		return DeviceAuthorizationFlow{}, err
	}

	protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
	if err != nil {
		if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
			log.Warnf("server couldn't find device flow, contact admin: %v", err)
			return DeviceAuthorizationFlow{}, err
		}
		log.Errorf("failed to retrieve device flow: %v", err)
		return DeviceAuthorizationFlow{}, err
	}

	// read every field through its generated getter: a missing provider config
	// then yields empty values that the validation below rejects, instead of a
	// nil pointer dereference
	providerConfig := protoDeviceAuthorizationFlow.GetProviderConfig()

	deviceAuthorizationFlow := DeviceAuthorizationFlow{
		Provider: protoDeviceAuthorizationFlow.Provider.String(),

		ProviderConfig: DeviceAuthProviderConfig{
			Audience:           providerConfig.GetAudience(),
			ClientID:           providerConfig.GetClientID(),
			ClientSecret:       providerConfig.GetClientSecret(),
			Domain:             providerConfig.GetDomain(),
			TokenEndpoint:      providerConfig.GetTokenEndpoint(),
			DeviceAuthEndpoint: providerConfig.GetDeviceAuthEndpoint(),
			Scope:              providerConfig.GetScope(),
			UseIDToken:         providerConfig.GetUseIDToken(),
		},
	}

	// keep compatibility with older management versions
	if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
		deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
	}

	if err := isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig); err != nil {
		return DeviceAuthorizationFlow{}, err
	}

	return deviceAuthorizationFlow, nil
}
// isDeviceAuthProviderConfigValid returns an error naming the first required
// setting that is empty. The settings are checked in this order: Audience,
// ClientID, TokenEndpoint, DeviceAuthEndpoint, Scope. It returns nil when all
// of them are set.
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
	const errorMSGFormat = "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"

	required := []struct {
		label string
		value string
	}{
		{"Audience", config.Audience},
		{"Client ID", config.ClientID},
		{"Token Endpoint", config.TokenEndpoint},
		{"Device Auth Endpoint", config.DeviceAuthEndpoint},
		{"Device Auth Scopes", config.Scope},
	}

	for _, field := range required {
		if field.value == "" {
			return fmt.Errorf(errorMSGFormat, field.label)
		}
	}
	return nil
}

View File

@@ -112,54 +112,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false,
shouldMatch: false,
},
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
}
for _, tt := range tests {

View File

@@ -9,10 +9,8 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -40,9 +38,6 @@ const (
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
}
func newHostManager() (*systemConfigurator, error) {
@@ -223,7 +218,6 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false
inServerAddressesArray := false
@@ -250,12 +244,9 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
@@ -267,19 +258,9 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port
dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil
}
// getOriginalNameservers returns a copy of origNameservers, taken while
// holding the read lock.
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
	s.mu.RLock()
	nameservers := slices.Clone(s.origNameservers)
	s.mu.RUnlock()

	return nameservers
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {

View File

@@ -109,169 +109,3 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput()
return err
}
// TestGetOriginalNameservers checks that getOriginalNameservers returns the
// preset nameservers in their original order.
func TestGetOriginalNameservers(t *testing.T) {
	want := []netip.Addr{
		netip.MustParseAddr("8.8.8.8"),
		netip.MustParseAddr("1.1.1.1"),
	}

	configurator := &systemConfigurator{
		createdKeys:     make(map[string]struct{}),
		origNameservers: want,
	}

	got := configurator.getOriginalNameservers()

	assert.Len(t, got, 2)
	for i, addr := range want {
		assert.Equal(t, addr, got[i])
	}
}
// TestGetOriginalNameserversFromSystem reads the system DNS settings and
// checks that at least one nameserver was recorded, and that every recorded
// address is valid and not unspecified.
func TestGetOriginalNameserversFromSystem(t *testing.T) {
	configurator := &systemConfigurator{createdKeys: make(map[string]struct{})}

	if _, err := configurator.getSystemDNSSettings(); err != nil {
		require.NoError(t, err)
	}

	found := configurator.getOriginalNameservers()
	require.NotEmpty(t, found, "expected at least one DNS server from system configuration")

	for i := range found {
		assert.True(t, found[i].IsValid(), "server address should be valid")
		assert.False(t, found[i].IsUnspecified(), "server address should not be unspecified")
	}

	t.Logf("found %d original nameservers: %v", len(found), found)
}
// setupTestConfigurator builds a fresh systemConfigurator together with a
// started state manager backed by a file in a per-test temporary directory.
// The returned cleanup function stops the state manager and removes the
// search, match and local DNS test keys; errors during cleanup are ignored.
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
	t.Helper()

	sm := statemanager.New(filepath.Join(t.TempDir(), "state.json"))
	sm.RegisterState(&ShutdownState{})
	sm.Start()

	configurator := &systemConfigurator{
		createdKeys: make(map[string]struct{}),
	}

	dnsKeys := []string{
		getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix),
		getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix),
		getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix),
	}

	cleanup := func() {
		_ = sm.Stop(context.Background())
		for i := range dnsKeys {
			_ = removeTestDNSKey(dnsKeys[i])
		}
	}

	return configurator, sm, cleanup
}
// TestOriginalNameserversNoTransition applies the same DNS configuration twice,
// once per RouteAll value, and checks that the recorded original nameservers
// stay equal to the ones read from the system beforehand.
func TestOriginalNameserversNoTransition(t *testing.T) {
	netbirdIP := netip.MustParseAddr("100.64.0.1")

	scenarios := []struct {
		name     string
		routeAll bool
	}{
		{name: "routeall_false", routeAll: false},
		{name: "routeall_true", routeAll: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			configurator, sm, cleanup := setupTestConfigurator(t)
			defer cleanup()

			_, err := configurator.getSystemDNSSettings()
			require.NoError(t, err)

			baseline := configurator.getOriginalNameservers()
			t.Logf("Initial servers: %v", baseline)
			require.NotEmpty(t, baseline)

			// the system must not already point at the NetBird resolver
			for i := range baseline {
				require.NotEqual(t, netbirdIP, baseline[i], "initial servers should not contain NetBird IP")
			}

			config := HostDNSConfig{
				ServerIP:   netbirdIP,
				ServerPort: 53,
				RouteAll:   sc.routeAll,
				Domains:    []DomainConfig{{Domain: "example.com", MatchOnly: true}},
			}

			for attempt := 1; attempt <= 2; attempt++ {
				err = configurator.applyDNSConfig(config, sm)
				require.NoError(t, err)

				current := configurator.getOriginalNameservers()
				t.Logf("After apply %d (RouteAll=%v): %v", attempt, sc.routeAll, current)
				assert.Equal(t, baseline, current)
			}
		})
	}
}
// TestOriginalNameserversRouteAllTransition applies a DNS configuration, then
// toggles RouteAll and toggles it back, checking after each apply that the
// recorded original nameservers still equal the ones read from the system
// beforehand. Finally it checks that none of them is the NetBird IP.
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
	netbirdIP := netip.MustParseAddr("100.64.0.1")

	scenarios := []struct {
		name         string
		initialRoute bool
	}{
		{name: "start_with_routeall_false", initialRoute: false},
		{name: "start_with_routeall_true", initialRoute: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			configurator, sm, cleanup := setupTestConfigurator(t)
			defer cleanup()

			_, err := configurator.getSystemDNSSettings()
			require.NoError(t, err)

			baseline := configurator.getOriginalNameservers()
			t.Logf("Initial servers: %v", baseline)
			require.NotEmpty(t, baseline)

			config := HostDNSConfig{
				ServerIP:   netbirdIP,
				ServerPort: 53,
				RouteAll:   sc.initialRoute,
				Domains:    []DomainConfig{{Domain: "example.com", MatchOnly: true}},
			}

			// first apply, toggle, toggle back
			steps := []struct {
				logFormat string
				routeAll  bool
			}{
				{"After first apply (RouteAll=%v): %v", sc.initialRoute},
				{"After toggle (RouteAll=%v): %v", !sc.initialRoute},
				{"After toggle back (RouteAll=%v): %v", sc.initialRoute},
			}

			var current []netip.Addr
			for _, step := range steps {
				config.RouteAll = step.routeAll

				err = configurator.applyDNSConfig(config, sm)
				require.NoError(t, err)

				current = configurator.getOriginalNameservers()
				t.Logf(step.logFormat, step.routeAll, current)
				assert.Equal(t, baseline, current)
			}

			for i := range current {
				assert.NotEqual(t, netbirdIP, current[i], "servers should not contain NetBird IP")
			}
		})
	}
}

View File

@@ -81,10 +81,7 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
logger := log.WithField("request_id", resutil.GetRequestID(w))
if len(r.Question) == 0 {
logger.Debug("received local resolver request with no question")
@@ -123,7 +120,7 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
return dns.RcodeSuccess
}
@@ -167,15 +164,11 @@ func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dn
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
if !exists && supportsWildcard(qType) {
testWild := transformDomainToWildcard(string(domainName))
_, exists = d.domains[domain.Domain(testWild)]
}
return exists
}
@@ -202,16 +195,6 @@ type lookupResult struct {
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock()
records, found := d.records[question]
usingWildcard := false
wildQuestion := transformToWildcard(question)
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
// If the domain exists with any record type, return NODATA instead of wildcard match.
if !found && supportsWildcard(question.Qtype) {
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
records, found = d.records[wildQuestion]
usingWildcard = found
}
}
if !found {
d.mu.RUnlock()
@@ -233,53 +216,18 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
q := question
if usingWildcard {
q = wildQuestion
}
records = d.records[q]
records = d.records[question]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[q] = records
d.records[question] = records
}
d.mu.Unlock()
}
if usingWildcard {
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
}
// transformToWildcard returns a copy of question whose name has its first
// label replaced by "*"; all other fields are kept.
func transformToWildcard(question dns.Question) dns.Question {
	// question is received by value, so the caller's copy is not modified
	question.Name = transformDomainToWildcard(question.Name)
	return question
}
// transformDomainToWildcard replaces everything before the first "." in domain
// with "*". A domain without any "." becomes "*".
func transformDomainToWildcard(domain string) string {
	_, remainder, hasDot := strings.Cut(domain, ".")
	if !hasDot {
		return "*"
	}
	return "*." + remainder
}
// supportsWildcard reports whether queryType may be answered from a wildcard
// record: every type except NS and SOA.
func supportsWildcard(queryType uint16) bool {
	switch queryType {
	case dns.TypeNS, dns.TypeSOA:
		return false
	default:
		return true
	}
}
// responseFromWildRecords builds a successful lookup result from copies of
// wildRecords, with each copy's header name set to originalName. The input
// records are not modified. wildName is not used.
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
	renamed := make([]dns.RR, 0, len(wildRecords))
	for _, wild := range wildRecords {
		rr := dns.Copy(wild)
		rr.Header().Name = originalName
		renamed = append(renamed, rr)
	}

	return lookupResult{records: renamed, rcode: dns.RcodeSuccess}
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
@@ -289,13 +237,6 @@ func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Questio
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
wildQuestion := transformToWildcard(cnameQuestion)
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
}
}
if len(cnameRecords) == 0 {
break
}
@@ -362,7 +303,7 @@ func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targ
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
if d.hasRecordsForDomain(domain.Domain(targetName)) {
return lookupResult{rcode: dns.RcodeSuccess}
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,9 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
@@ -29,8 +27,6 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
@@ -443,17 +439,6 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)
@@ -630,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config)
}
// registerFallback registers original nameservers as low-priority fallback handlers.
// registerFallback registers original nameservers as low-priority fallback handlers
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
@@ -639,7 +624,6 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return
}

View File

@@ -8,21 +8,15 @@ import (
// MockResponseWriter is a test double for a DNS response writer.
type MockResponseWriter struct {
// WriteMsgFunc, when set, is called by WriteMsg and its result is returned
WriteMsgFunc func(m *dns.Msg) error
// lastResponse is the message most recently passed to WriteMsg
lastResponse *dns.Msg
}
// WriteMsg records m as the last response and, when WriteMsgFunc is set,
// delegates to it and returns its result. Without a hook it returns nil.
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
	rw.lastResponse = m

	if rw.WriteMsgFunc == nil {
		return nil
	}
	return rw.WriteMsgFunc(m)
}
// GetLastResponse returns the message most recently passed to WriteMsg,
// or nil if WriteMsg has not been called.
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
// LocalAddr is a stub that always returns nil.
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
// RemoteAddr is a stub that always returns nil.
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
// Write is a stub that discards its input and reports zero bytes written.
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -71,11 +71,6 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status
}
// upstreamFailure describes one failed attempt to query an upstream server.
type upstreamFailure struct {
// upstream is the address and port of the server that was queried
upstream netip.AddrPort
// reason is a human-readable description of why the attempt failed
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
@@ -119,10 +114,7 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
logger := log.WithField("request_id", resutil.GetRequestID(w))
u.prepareRequest(r)
@@ -131,13 +123,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
ok, failures := u.tryUpstreamServers(w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
if !ok {
u.writeErrorResponse(w, r, logger)
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
@@ -146,7 +136,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -159,19 +149,15 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
}
}
return false, failures
return false
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
var rm *dns.Msg
var t time.Duration
var err error
@@ -185,32 +171,31 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
}()
if err != nil {
return u.handleUpstreamError(err, upstream, startTime)
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
return false
}
if rm == nil || !rm.Response {
return &upstreamFailure{upstream: upstream, reason: "no response"}
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
return nil
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
}
elapsed := time.Since(startTime)
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
reason += " " + peerInfo
timeoutMsg += " " + peerInfo
}
return &upstreamFailure{upstream: upstream, reason: reason}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warn(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
@@ -230,34 +215,16 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
return true
}
// logUpstreamFailures logs a one-line summary of the failed upstream attempts
// for domain. The summary is logged as a warning when the query eventually
// succeeded and as an error otherwise.
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
	const summaryFormat = "%d/%d upstreams failed for domain=%s: %s"

	emit := logger.Errorf
	if succeeded {
		emit = logger.Warnf
	}

	emit(summaryFormat, len(failures), len(u.upstreamServers), domain, formatFailures(failures))
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
}
// formatFailures renders failures as a comma-separated list of
// "upstream=reason" entries. An empty input yields an empty string.
func formatFailures(failures []upstreamFailure) string {
	var summary strings.Builder
	for i, failure := range failures {
		if i > 0 {
			summary.WriteString(", ")
		}
		fmt.Fprintf(&summary, "%s=%s", failure.upstream, failure.reason)
	}
	return summary.String()
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
@@ -501,6 +468,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
@@ -10,8 +9,6 @@ import (
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
@@ -143,23 +140,6 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
return c.r, c.rtt, c.err
}
// mockUpstreamResponse is the canned outcome a mocked upstream hands back:
// a reply message, an error, or neither.
type mockUpstreamResponse struct {
	// msg is the reply returned by the mock; tests leave it nil to simulate
	// an upstream that produced no message.
	msg *dns.Msg
	// err is the error returned by the mock alongside msg.
	err error
}
// mockUpstreamResolverPerServer is a test double that answers with a fixed,
// pre-registered outcome for each upstream address.
type mockUpstreamResolverPerServer struct {
	// responses maps an upstream address string to its canned outcome.
	responses map[string]mockUpstreamResponse
	// rtt is reported as the round-trip time of every exchange.
	rtt time.Duration
}

// exchange returns the canned message and error registered for upstream,
// together with the configured rtt. An upstream without a registered outcome
// yields a nil message and a descriptive error.
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
	canned, registered := c.responses[upstream]
	if !registered {
		return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
	}
	return canned.msg, c.rtt, canned.err
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
@@ -211,267 +191,3 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
t.Errorf("should be enabled")
}
}
// TestUpstreamResolver_Failover runs ServeDNS against a resolver configured
// with two mocked upstreams. For each table row it checks the rcode written to
// the client, whether an answer record is present, and whether the second
// upstream was queried after the first.
func TestUpstreamResolver_Failover(t *testing.T) {
	upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
	upstream2 := netip.MustParseAddrPort("192.0.2.2:53")

	// IP placed in the A record of every successful mock reply.
	successAnswer := "192.0.2.100"
	// Shared error value used by the rows that simulate a timing-out upstream.
	timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}

	testCases := []struct {
		name string
		// upstream1 and upstream2 are the canned outcomes of the first and
		// second configured upstream respectively.
		upstream1 mockUpstreamResponse
		upstream2 mockUpstreamResponse
		// expectedRcode is the rcode the client must receive.
		expectedRcode int
		// expectAnswer requires a non-empty answer containing successAnswer.
		expectAnswer bool
		// expectTrySecond requires both upstreams to be queried in order;
		// when false only the first upstream may be queried.
		expectTrySecond bool
	}{
		{
			name: "success on first upstream",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			expectedRcode: dns.RcodeSuccess,
			expectAnswer: true,
			expectTrySecond: false,
		},
		{
			name: "SERVFAIL from first should try second",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			expectedRcode: dns.RcodeSuccess,
			expectAnswer: true,
			expectTrySecond: true,
		},
		{
			name: "REFUSED from first should try second",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			expectedRcode: dns.RcodeSuccess,
			expectAnswer: true,
			expectTrySecond: true,
		},
		{
			// NXDOMAIN is expected to be passed to the client without failover.
			name: "NXDOMAIN from first should NOT try second",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			expectedRcode: dns.RcodeNameError,
			expectAnswer: false,
			expectTrySecond: false,
		},
		{
			name: "timeout from first should try second",
			upstream1: mockUpstreamResponse{err: timeoutErr},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			expectedRcode: dns.RcodeSuccess,
			expectAnswer: true,
			expectTrySecond: true,
		},
		{
			// Neither a message nor an error comes back from the first upstream.
			name: "no response from first should try second",
			upstream1: mockUpstreamResponse{msg: nil},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
			expectedRcode: dns.RcodeSuccess,
			expectAnswer: true,
			expectTrySecond: true,
		},
		{
			name: "both upstreams return SERVFAIL",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
			expectedRcode: dns.RcodeServerFailure,
			expectAnswer: false,
			expectTrySecond: true,
		},
		{
			name: "both upstreams timeout",
			upstream1: mockUpstreamResponse{err: timeoutErr},
			upstream2: mockUpstreamResponse{err: timeoutErr},
			expectedRcode: dns.RcodeServerFailure,
			expectAnswer: false,
			expectTrySecond: true,
		},
		{
			name: "first SERVFAIL then timeout",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
			upstream2: mockUpstreamResponse{err: timeoutErr},
			expectedRcode: dns.RcodeServerFailure,
			expectAnswer: false,
			expectTrySecond: true,
		},
		{
			name: "first timeout then SERVFAIL",
			upstream1: mockUpstreamResponse{err: timeoutErr},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
			expectedRcode: dns.RcodeServerFailure,
			expectAnswer: false,
			expectTrySecond: true,
		},
		{
			name: "first REFUSED then SERVFAIL",
			upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
			upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
			expectedRcode: dns.RcodeServerFailure,
			expectAnswer: false,
			expectTrySecond: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Records, in call order, every upstream address the resolver queries.
			var queriedUpstreams []string
			mockClient := &mockUpstreamResolverPerServer{
				responses: map[string]mockUpstreamResponse{
					upstream1.String(): tc.upstream1,
					upstream2.String(): tc.upstream2,
				},
				rtt: time.Millisecond,
			}

			// Wrap the canned client so the order of queried upstreams can be asserted.
			trackingClient := &trackingMockClient{
				inner: mockClient,
				queriedUpstreams: &queriedUpstreams,
			}

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			resolver := &upstreamResolverBase{
				ctx: ctx,
				upstreamClient: trackingClient,
				upstreamServers: []netip.AddrPort{upstream1, upstream2},
				upstreamTimeout: UpstreamTimeout,
			}

			// Capture whatever message the resolver writes back to the client.
			var responseMSG *dns.Msg
			responseWriter := &test.MockResponseWriter{
				WriteMsgFunc: func(m *dns.Msg) error {
					responseMSG = m
					return nil
				},
			}

			inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
			resolver.ServeDNS(responseWriter, inputMSG)

			require.NotNil(t, responseMSG, "should write a response")
			assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")

			if tc.expectAnswer {
				require.NotEmpty(t, responseMSG.Answer, "expected answer records")
				assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
			}

			if tc.expectTrySecond {
				assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
				assert.Equal(t, upstream1.String(), queriedUpstreams[0])
				assert.Equal(t, upstream2.String(), queriedUpstreams[1])
			} else {
				assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
				assert.Equal(t, upstream1.String(), queriedUpstreams[0])
			}
		})
	}
}
// trackingMockClient decorates a mockUpstreamResolverPerServer and records
// which upstreams were queried, in call order.
type trackingMockClient struct {
	// inner produces the actual canned responses.
	inner *mockUpstreamResolverPerServer
	// queriedUpstreams points at the caller-owned slice that receives every
	// queried upstream address.
	queriedUpstreams *[]string
}

// exchange appends upstream to the shared record and then delegates the
// exchange to the wrapped mock.
func (c *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
	record := c.queriedUpstreams
	*record = append(*record, upstream)

	return c.inner.exchange(ctx, upstream, r)
}
// buildMockResponse creates a DNS reply with the given rcode. When rcode is
// RcodeSuccess and answer is non-empty, the reply carries a single A record
// for "example.com." (TTL 300) whose address is parsed from answer; otherwise
// the reply has no answer section.
func buildMockResponse(rcode int, answer string) *dns.Msg {
	reply := new(dns.Msg)
	reply.Response = true
	reply.Rcode = rcode

	// Only successful replies with a provided address get an answer record.
	if rcode != dns.RcodeSuccess || answer == "" {
		return reply
	}

	header := dns.RR_Header{
		Name:   "example.com.",
		Rrtype: dns.TypeA,
		Class:  dns.ClassINET,
		Ttl:    300,
	}
	reply.Answer = []dns.RR{
		&dns.A{Hdr: header, A: net.ParseIP(answer)},
	}
	return reply
}
// TestUpstreamResolver_SingleUpstreamFailure checks that a resolver with one
// upstream that answers SERVFAIL writes a SERVFAIL response to the client.
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
	onlyUpstream := netip.MustParseAddrPort("192.0.2.1:53")

	// The single upstream always answers with SERVFAIL.
	servfailClient := &mockUpstreamResolverPerServer{
		responses: map[string]mockUpstreamResponse{
			onlyUpstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
		},
		rtt: time.Millisecond,
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	resolver := &upstreamResolverBase{
		ctx:             ctx,
		upstreamClient:  servfailClient,
		upstreamServers: []netip.AddrPort{onlyUpstream},
		upstreamTimeout: UpstreamTimeout,
	}

	// Capture the message written back to the client.
	var written *dns.Msg
	capture := &test.MockResponseWriter{
		WriteMsgFunc: func(m *dns.Msg) error {
			written = m
			return nil
		},
	}

	question := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
	resolver.ServeDNS(capture, question)

	require.NotNil(t, written, "should write a response")
	assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
// TestFormatFailures verifies the "<upstream>=<reason>" rendering of
// formatFailures for empty, single-entry and multi-entry inputs.
func TestFormatFailures(t *testing.T) {
	primary := netip.MustParseAddrPort("8.8.8.8:53")
	secondary := netip.MustParseAddrPort("8.8.4.4:53")

	scenarios := []struct {
		name     string
		failures []upstreamFailure
		expected string
	}{
		{
			name:     "empty slice",
			failures: []upstreamFailure{},
			expected: "",
		},
		{
			name: "single failure",
			failures: []upstreamFailure{
				{upstream: primary, reason: "SERVFAIL"},
			},
			expected: "8.8.8.8:53=SERVFAIL",
		},
		{
			name: "multiple failures",
			failures: []upstreamFailure{
				{upstream: primary, reason: "SERVFAIL"},
				{upstream: secondary, reason: "timeout after 2s"},
			},
			expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			assert.Equal(t, sc.expected, formatFailures(sc.failures))
		})
	}
}

View File

@@ -190,75 +190,50 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return
return nil
}
question := query.Question[0]
qname := strings.ToLower(question.Name)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(qname, question.Qtype, result.IPs)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
f.writeResponse(logger, w, resp, qname, startTime)
}
// writeResponse sends resp through w. A write failure is logged as an error;
// on success a trace line records the domain, rcode, answers and the time
// elapsed since startTime.
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
	err := w.WriteMsg(resp)
	if err == nil {
		logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
			qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
		return
	}

	logger.Errorf("failed to write DNS response: %v", err)
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
	// ResponseWriter is the embedded writer that receives the final message.
	dns.ResponseWriter
	// query is the request being answered; WriteMsg reads its EDNS0 option
	// to decide the maximum response size.
	query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -268,7 +243,30 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -278,7 +276,18 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, w, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -325,7 +334,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
startTime time.Time,
) {
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
@@ -335,7 +343,9 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
@@ -345,7 +355,9 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
@@ -353,7 +365,9 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
}
@@ -361,12 +375,15 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
}
f.writeResponse(logger, w, resp, domain, startTime)
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,9 +318,8 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -330,9 +329,10 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
@@ -466,16 +466,14 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else {
require.NotNil(t, resp, "Expected response")
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
@@ -530,10 +528,9 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -608,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -678,8 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -687,13 +683,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -719,8 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -732,13 +727,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2)
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -789,9 +784,8 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -903,15 +897,26 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, resp.Answer, "Response should have no answer records")
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
@@ -926,8 +931,15 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{}
// Don't set any question
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -28,11 +28,9 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
@@ -44,14 +42,12 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -136,11 +132,6 @@ type EngineConfig struct {
LazyConnectionEnabled bool
MTU uint16
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogPath string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -204,8 +195,7 @@ type Engine struct {
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
// Sync response persistence
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
@@ -221,9 +211,6 @@ type Engine struct {
shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -237,18 +224,7 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -268,7 +244,6 @@ func NewEngine(
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -337,8 +312,6 @@ func (e *Engine) Stop() error {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
// stop flow manager after wg interface is gone
@@ -506,15 +479,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// Set up notrack rules immediately after proxy is listening to prevent
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
@@ -536,7 +500,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
@@ -544,12 +507,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() {
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart()
} else if err != nil {
@@ -575,11 +537,9 @@ func (e *Engine) createFirewall() error {
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil {
return fmt.Errorf("create firewall manager: %w", err)
}
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
}
if err := e.initFirewall(); err != nil {
@@ -625,23 +585,6 @@ func (e *Engine) initFirewall() error {
return nil
}
// setupWGProxyNoTrack asks the firewall to exclude WireGuard proxy traffic from
// connection tracking, so conntrack/MASQUERADE does not affect the loopback
// traffic between WireGuard and the eBPF proxy.
//
// It does nothing when no firewall is configured or when the interface
// reports no proxy port. A setup failure is logged as a warning only.
func (e *Engine) setupWGProxyNoTrack() {
	if e.firewall != nil {
		if port := e.wgInterface.GetProxyPort(); port != 0 {
			err := e.firewall.SetupEBPFProxyNoTrack(port, uint16(e.config.WgPort))
			if err != nil {
				log.Warnf("failed to setup ebpf proxy notrack: %v", err)
			}
		}
	}
}
func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
@@ -830,10 +773,6 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -889,18 +828,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if enabled {
e.syncRespMux.Lock()
if e.persistSyncResponse {
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
@@ -1023,87 +953,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn()
e.statusRecorder.UpdateLocalPeerState(state)
return nil
}
// receiveJobEvents starts a goroutine that subscribes to the Management
// Service job stream and answers every incoming job request. The goroutine is
// tracked by jobExecutorWG.
//
// If the stream returns an error, the engine state is wrapped with
// ErrResetConnection and the whole client is cancelled.
func (e *Engine) receiveJobEvents() {
	// handleJob builds the response for one request. The status stays
	// "failed" unless a supported workload completes successfully.
	handleJob := func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
		resp := &mgmProto.JobResponse{
			ID:     msg.ID,
			Status: mgmProto.JobStatus_failed,
		}

		// Debug bundles are the only workload handled here.
		bundleReq, isBundle := msg.WorkloadParameters.(*mgmProto.JobRequest_Bundle)
		if !isBundle {
			resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
			return resp
		}

		bundleResult, err := e.handleBundle(bundleReq.Bundle)
		if err != nil {
			log.Errorf("handling bundle: %v", err)
			resp.Reason = []byte(err.Error())
			return resp
		}

		resp.Status = mgmProto.JobStatus_succeeded
		resp.WorkloadResults = bundleResult
		return resp
	}

	e.jobExecutorWG.Add(1)
	go func() {
		defer e.jobExecutorWG.Done()

		if err := e.mgmClient.Job(e.ctx, handleJob); err != nil {
			// happens if management is unavailable for a long time.
			// We want to cancel the operation of the whole client
			_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
			e.clientCancel()
			return
		}

		log.Info("stopped receiving jobs from Management Service")
	}()

	log.Info("connecting to Management Service jobs stream")
}
// handleBundle executes a remote debug-bundle job and returns the upload key of the
// generated bundle wrapped in a JobResponse_Bundle. An error from the bundle job is
// returned unchanged.
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
	log.Infof("handle remote debug bundle request: %s", params.String())

	// Failing to obtain the sync response is only logged; the bundle is still generated
	// with whatever value was returned.
	latestSync, syncErr := e.GetLatestSyncResponse()
	if syncErr != nil {
		log.Warnf("get latest sync response: %v", syncErr)
	}

	deps := debug.GeneratorDependencies{
		InternalConfig: e.config.ProfileConfig,
		StatusRecorder: e.statusRecorder,
		SyncResponse:   latestSync,
		LogPath:        e.config.LogPath,
		RefreshStatus: func() {
			e.RunHealthProbes(true)
		},
	}

	cfg := debug.BundleConfig{
		Anonymize:         params.Anonymize,
		IncludeSystemInfo: true,
		LogFileCount:      uint32(params.LogFileCount),
	}

	// BundleForTime is expressed in minutes.
	collectFor := time.Duration(params.BundleForTime) * time.Minute

	uploadKey, err := e.jobExecutor.BundleJob(e.ctx, deps, cfg, collectFor, e.config.ProfileConfig.ManagementURL.String())
	if err != nil {
		return nil, err
	}

	return &mgmProto.JobResponse_Bundle{
		Bundle: &mgmProto.BundleResult{
			UploadKey: uploadKey,
		},
	}, nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
@@ -1549,7 +1405,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
if e.rpManager != nil {
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized)
}
return peerConn, nil
@@ -1673,7 +1528,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
func (e *Engine) close() {
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
@@ -1860,7 +1714,7 @@ func (e *Engine) getRosenpassAddr() string {
return ""
}
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock()
@@ -1874,8 +1728,23 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
log.Debugf("failed to refresh WireGuard stats: %v", err)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
}
e.syncMsgMux.Unlock()
@@ -1924,7 +1793,7 @@ func (e *Engine) triggerClientRestart() {
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}
@@ -1979,8 +1848,8 @@ func (e *Engine) stopDNSServer() {
// SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if enabled == e.persistSyncResponse {
return
@@ -1995,22 +1864,20 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if !enabled {
if !e.persistSyncResponse {
return nil, errors.New("sync response persistence is disabled")
}
if latest == nil {
if e.latestSyncResponse == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
@@ -73,16 +72,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
@@ -95,10 +87,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
if netstack.IsEnabled() {
return nil
}
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
@@ -221,10 +209,6 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
if netstack.IsEnabled() {
return
}
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {

View File

@@ -25,7 +25,6 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
@@ -107,7 +106,6 @@ type MockWGIface struct {
GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetProxyPortFunc func() uint16
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
}
@@ -204,13 +202,6 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
return m.GetProxyFunc()
}
// GetProxyPort returns the value produced by the GetProxyPortFunc hook,
// or 0 when no hook has been configured.
func (m *MockWGIface) GetProxyPort() uint16 {
	if m.GetProxyPortFunc == nil {
		return 0
	}
	return m.GetProxyPortFunc()
}
// GetNet returns whatever the GetNetFunc hook returns.
// The hook is invoked unconditionally, so it must be set before GetNet is called.
func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}
@@ -222,10 +213,6 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
// SetPresharedKey is a no-op stub: it ignores all of its arguments and always returns nil.
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
@@ -1612,7 +1599,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
@@ -1636,7 +1622,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
@@ -1645,7 +1631,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -28,7 +28,6 @@ type wgIfaceBase interface {
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error
@@ -43,5 +42,4 @@ type wgIfaceBase interface {
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}

View File

@@ -11,7 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -75,13 +74,12 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is used on Windows, JS, and netstack platforms:
// BindListener is only used on Windows and JS platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
return NewUDPListener(m.wgIface, peerCfg)
}

201
client/internal/login.go Normal file
View File

@@ -0,0 +1,201 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired reports whether the peer has to log in before it can use the Management service.
// It attempts a login with the keys stored in config and returns true when isLoginNeeded
// classifies the resulting error as a login requirement. Any other error is returned as-is
// together with false.
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
	mgmURL := config.ManagementURL

	mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
	if err != nil {
		return false, err
	}
	defer func() {
		closeErr := mgmClient.Close()
		if closeErr == nil {
			return
		}
		// Close errors carrying the gRPC Canceled code are deliberately not logged.
		if st, isStatus := status.FromError(closeErr); !isStatus || st.Code() != codes.Canceled {
			log.Warnf("failed to close the Management service client, err: %v", closeErr)
		}
	}()
	log.Debugf("connected to the Management service %s", mgmURL.String())

	pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
	if err != nil {
		return false, err
	}

	if _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config); isLoginNeeded(err) {
		return true, nil
	}
	return false, err
}
// Login logs the peer in to the Management service and, when the service answers that the
// peer is not registered yet, registers it using the given setup key or JWT token.
// It returns nil on a successful login or registration.
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
	mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
	if err != nil {
		return err
	}
	defer func() {
		closeErr := mgmClient.Close()
		if closeErr == nil {
			return
		}
		// Close errors carrying the gRPC Canceled code are deliberately not logged.
		if st, isStatus := status.FromError(closeErr); !isStatus || st.Code() != codes.Canceled {
			log.Warnf("failed to close the Management service client, err: %v", closeErr)
		}
	}()
	log.Debugf("connected to the Management service %s", config.ManagementURL.String())

	pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
	if err != nil {
		return err
	}

	serverKey, _, loginErr := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
	if loginErr == nil {
		return nil
	}
	// Registration is only attempted when the server key is known and the login failure
	// indicates an unregistered peer; every other failure is reported to the caller.
	if serverKey == nil || !isRegistrationNeeded(loginErr) {
		return loginErr
	}

	log.Debugf("peer registration required")
	_, regErr := registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
	return regErr
}
// getMgmClient validates the peer's WireGuard private key and opens a gRPC client to the
// Management service at mgmURL. TLS is enabled when the URL scheme is "https".
//
// The private key is a secret and must never be written to the logs; only the parse
// error is reported when validation fails.
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
	// validate our peer's Wireguard PRIVATE key
	myPrivateKey, err := wgtypes.ParseKey(privateKey)
	if err != nil {
		log.Errorf("failed parsing Wireguard key: [%s]", err.Error())
		return nil, err
	}

	mgmTlsEnabled := mgmURL.Scheme == "https"

	log.Debugf("connecting to the Management service %s", mgmURL.String())
	mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
	if err != nil {
		log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
		return nil, err
	}
	return mgmClient, nil
}
// doMgmLogin fetches the Management Service public key and attempts to log the peer in,
// reporting the local system information and the feature flags taken from config.
// The server key is returned alongside the login result (even when the login itself fails)
// so that callers can reuse it; it is nil only when the key could not be fetched.
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
	srvPubKey, keyErr := mgmClient.GetServerPublicKey()
	if keyErr != nil {
		log.Errorf("failed while getting Management Service public key: %v", keyErr)
		return nil, nil, keyErr
	}

	info := system.GetInfo(ctx)
	info.SetFlags(
		config.RosenpassEnabled,
		config.RosenpassPermissive,
		config.ServerSSHAllowed,
		config.DisableClientRoutes,
		config.DisableServerRoutes,
		config.DisableDNS,
		config.DisableFirewall,
		config.BlockLANAccess,
		config.BlockInbound,
		config.LazyConnectionEnabled,
		config.EnableSSHRoot,
		config.EnableSSHSFTP,
		config.EnableSSHLocalPortForwarding,
		config.EnableSSHRemotePortForwarding,
		config.DisableSSHAuth,
	)

	resp, loginErr := mgmClient.Login(*srvPubKey, info, pubSSHKey, config.DNSLabels)
	return srvPubKey, resp, loginErr
}
// registerPeer registers this peer with the Management Service using a setup key and/or a JWT token.
// The setup key must parse as a UUID unless a JWT token is supplied; when the key is invalid and
// no token is given, an InvalidArgument status error is returned. The function never prompts for input.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
	parsedKey, parseErr := uuid.Parse(setupKey)
	if parseErr != nil && jwtToken == "" {
		return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", parseErr)
	}

	log.Debugf("sending peer registration request to Management Service")

	sysInfo := system.GetInfo(ctx)
	sysInfo.SetFlags(
		config.RosenpassEnabled,
		config.RosenpassPermissive,
		config.ServerSSHAllowed,
		config.DisableClientRoutes,
		config.DisableServerRoutes,
		config.DisableDNS,
		config.DisableFirewall,
		config.BlockLANAccess,
		config.BlockInbound,
		config.LazyConnectionEnabled,
		config.EnableSSHRoot,
		config.EnableSSHSFTP,
		config.EnableSSHLocalPortForwarding,
		config.EnableSSHRemotePortForwarding,
		config.DisableSSHAuth,
	)

	resp, regErr := client.Register(serverPublicKey, parsedKey.String(), jwtToken, sysInfo, pubSSHKey, config.DNSLabels)
	if regErr != nil {
		log.Errorf("failed registering peer %v", regErr)
		return nil, regErr
	}

	log.Infof("peer has been successfully registered on Management Service")
	return resp, nil
}
// isLoginNeeded reports whether err is a gRPC status error whose code is
// InvalidArgument or PermissionDenied. Nil and non-status errors yield false.
func isLoginNeeded(err error) bool {
	if err == nil {
		return false
	}

	st, isStatus := status.FromError(err)
	if !isStatus {
		return false
	}

	switch st.Code() {
	case codes.InvalidArgument, codes.PermissionDenied:
		return true
	default:
		return false
	}
}
// isRegistrationNeeded reports whether err is a gRPC status error with the
// PermissionDenied code. Nil and non-status errors yield false.
func isRegistrationNeeded(err error) bool {
	if err == nil {
		return false
	}

	st, isStatus := status.FromError(err)
	return isStatus && st.Code() == codes.PermissionDenied
}

View File

@@ -88,9 +88,8 @@ type Conn struct {
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
rosenpassInitializedPresharedKeyValidator func(peerKey string) bool
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
statusRelay *worker.AtomicWorkerStatus
statusICE *worker.AtomicWorkerStatus
@@ -99,10 +98,7 @@ type Conn struct {
workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcher *WGWatcher
wgWatcherWg sync.WaitGroup
wgWatcherCancel context.CancelFunc
wgWatcherWg sync.WaitGroup
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
@@ -130,7 +126,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
@@ -142,9 +137,8 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
}
return conn, nil
@@ -168,7 +162,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -237,9 +231,7 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Infof("close peer connection")
conn.ctxCancel()
if conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
}
conn.workerRelay.DisableWgWatcher()
conn.workerRelay.CloseConn()
conn.workerICE.Close()
@@ -297,13 +289,6 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over
// PSK management for a peer. When this returns true, presharedKey() returns nil
// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK.
// The handler is stored as-is; this setter does not take any lock itself.
func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) {
conn.rosenpassInitializedPresharedKeyValidator = handler
}
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
@@ -381,6 +366,9 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
@@ -390,8 +378,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
@@ -437,6 +423,11 @@ func (conn *Conn) onICEStateDisconnected() {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
@@ -453,15 +444,15 @@ func (conn *Conn) onICEStateDisconnected() {
}
conn.statusICE.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
}
}
@@ -501,9 +492,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -512,6 +500,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
@@ -525,11 +519,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.handleRelayDisconnectedLocked()
}
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
func (conn *Conn) handleRelayDisconnectedLocked() {
if conn.ctx.Err() != nil {
return
}
@@ -555,8 +545,6 @@ func (conn *Conn) handleRelayDisconnectedLocked() {
}
conn.statusRelay.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
@@ -575,28 +563,6 @@ func (conn *Conn) onGuardEvent() {
}
}
// onWGDisconnected closes the currently active transport when a WireGuard handshake
// timeout is reported. It is passed to the WireGuard watcher as its disconnect callback.
// The whole handler runs under conn.mu and is a no-op once the connection context is done.
func (conn *Conn) onWGDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
// the connection is already being closed; nothing left to tear down
if conn.ctx.Err() != nil {
return
}
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
// Close the active connection based on current priority
switch conn.currentConnPriority {
case conntype.Relay:
conn.workerRelay.CloseConn()
// conn.mu is already held here, hence the *Locked variant
conn.handleRelayDisconnectedLocked()
case conntype.ICEP2P, conntype.ICETurn:
conn.workerICE.Close()
default:
conn.Log.Debugf("No active connection to close on WG timeout")
}
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
@@ -723,25 +689,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}
// enableWgWatcherIfNeeded starts the WireGuard watcher in its own goroutine unless it is
// already enabled. The watcher gets a child context of conn.ctx whose cancel function is
// kept in conn.wgWatcherCancel, and the goroutine is tracked by conn.wgWatcherWg.
func (conn *Conn) enableWgWatcherIfNeeded() {
	if conn.wgWatcher.IsEnabled() {
		return
	}

	watcherCtx, cancel := context.WithCancel(conn.ctx)
	conn.wgWatcherCancel = cancel

	conn.wgWatcherWg.Add(1)
	go func() {
		defer conn.wgWatcherWg.Done()
		conn.wgWatcher.EnableWgWatcher(watcherCtx, conn.onWGDisconnected)
	}()
}
// disableWgWatcherIfNeeded cancels the WireGuard watcher context, but only when no
// connection is active (priority None) and a cancel function is currently stored.
func (conn *Conn) disableWgWatcherIfNeeded() {
	if conn.currentConnPriority != conntype.None || conn.wgWatcherCancel == nil {
		return
	}

	conn.wgWatcherCancel()
	conn.wgWatcherCancel = nil
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
@@ -812,24 +759,10 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return conn.config.WgConfig.PreSharedKey
}
// If Rosenpass has already set a PSK for this peer, return nil to prevent
// UpdatePeer from overwriting the Rosenpass-managed key.
if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) {
return nil
}
// Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to
// Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum
// key is cryptographically bound to the original secret.
if conn.config.WgConfig.PreSharedKey != nil {
return conn.config.WgConfig.PreSharedKey
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
return conn.config.WgConfig.PreSharedKey
}
return determKey

View File

@@ -284,27 +284,3 @@ func TestConn_presharedKey(t *testing.T) {
})
}
}
func TestConn_presharedKey_RosenpassManaged(t *testing.T) {
conn := Conn{
config: ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")},
},
}
// When Rosenpass has already initialized the PSK for this peer,
// presharedKey must return nil to avoid UpdatePeer overwriting it.
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true }
if k := conn.presharedKey([]byte("remote")); k != nil {
t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k)
}
// When Rosenpass hasn't taken over yet, presharedKey should provide
// a non-nil initial key (deterministic or from NetBird PSK).
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false }
if k := conn.presharedKey([]byte("remote")); k == nil {
t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK")
}
}

View File

@@ -2,7 +2,6 @@ package ice
import (
"context"
"fmt"
"sync"
"time"
@@ -33,6 +32,24 @@ type ThreadSafeAgent struct {
once sync.Once
}
// Close shuts down the wrapped ICE agent at most once; later calls return nil without
// doing anything. The underlying Close runs in its own goroutine and is abandoned after
// iceAgentCloseTimeout, in which case a warning is logged and nil is returned.
//
// A nil embedded Agent is tolerated: the call is skipped with a warning instead of
// panicking with a nil pointer dereference.
func (a *ThreadSafeAgent) Close() error {
	var err error
	a.once.Do(func() {
		// Capture the agent once so the nil check and the Close call see the same value.
		agent := a.Agent
		if agent == nil {
			log.Warnf("ICE agent is nil during close, skipping")
			return
		}

		// Buffered so the closing goroutine can finish even if we stopped waiting.
		done := make(chan error, 1)
		go func() {
			done <- agent.Close()
		}()

		select {
		case err = <-done:
		case <-time.After(iceAgentCloseTimeout):
			log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
			err = nil
		}
	})
	return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -76,41 +93,9 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil
}
// Close shuts down the wrapped ICE agent at most once; later calls return nil without
// doing anything. The underlying Close runs in its own goroutine and is abandoned after
// iceAgentCloseTimeout, in which case a warning is logged and nil is returned.
func (a *ThreadSafeAgent) Close() error {
	var closeErr error
	a.once.Do(func() {
		// Defensive check to prevent nil pointer dereference
		// This can happen during sleep/wake transitions or memory corruption scenarios
		// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
		// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
		underlying := a.Agent
		if underlying == nil {
			log.Warnf("ICE agent is nil during close, skipping")
			return
		}

		// Buffered so the closing goroutine can finish even if we stopped waiting.
		result := make(chan error, 1)
		go func() {
			result <- underlying.Close()
		}()

		select {
		case <-time.After(iceAgentCloseTimeout):
			log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
			closeErr = nil
		case closeErr = <-result:
		}
	})
	return closeErr
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {

View File

@@ -1145,38 +1145,6 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
return d.wgIface.FullStats()
}
// RefreshWireGuardStats pulls the current WireGuard statistics from the interface and copies
// the last handshake time and the rx/tx byte counters into the cached state of every peer
// that is already tracked. Peers reported by the interface but unknown to the cache are ignored.
// It returns nil when no interface is set, and a wrapped error when the stats cannot be read.
func (d *Status) RefreshWireGuardStats() error {
	d.mux.Lock()
	defer d.mux.Unlock()

	if d.wgIface == nil {
		return nil // silently skip if interface not set
	}

	ifaceStats, err := d.wgIface.FullStats()
	if err != nil {
		return fmt.Errorf("get wireguard stats: %w", err)
	}

	for _, wgPeer := range ifaceStats.Peers {
		cached, known := d.peers[wgPeer.PublicKey]
		if known {
			cached.LastWireguardHandshake = wgPeer.LastHandshake
			cached.BytesRx = wgPeer.RxBytes
			cached.BytesTx = wgPeer.TxBytes
			// written back explicitly after modifying the looked-up value
			d.peers[wgPeer.PublicKey] = cached
		}
	}
	return nil
}
type EventQueue struct {
maxSize int
events []*proto.SystemEvent

View File

@@ -30,8 +30,10 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
enabled bool
muEnabled sync.RWMutex
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -44,44 +46,52 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
return
}
w.log.Debugf("enable WireGuard watcher")
enabledTime := time.Now()
w.enabled = true
w.muEnabled.Unlock()
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
}
// IsEnabled returns true if the WireGuard watcher is currently enabled
func (w *WGWatcher) IsEnabled() bool {
w.muEnabled.RLock()
defer w.muEnabled.RUnlock()
return w.enabled
// DisableWgWatcher stops the WireGuard watcher by cancelling its context; it does not wait for the watcher to exit
func (w *WGWatcher) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancel == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancel()
w.ctxCancel = nil
}
// periodicHandshakeCheck helps to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
lastHandshake := initialHandshake
@@ -94,7 +104,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
return
}
if lastHandshake.IsZero() {
elapsed := calcElapsed(enabledTime, *handshake)
elapsed := handshake.Sub(w.enabledTime).Seconds()
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
@@ -124,19 +134,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
// the current know handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
return nil, false
}
// in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
return nil, false
}
// error handling for handshake time in the future
if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
return nil, false
}
@@ -154,13 +164,3 @@ func (w *WGWatcher) wgState() (time.Time, error) {
}
return wgState.LastHandshake, nil
}
// calcElapsed calculates elapsed time since watcher was enabled.
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
func calcElapsed(enabledTime, handshake time.Time) float64 {
elapsed := handshake.Sub(enabledTime).Seconds()
if elapsed < 0 {
elapsed = 0
}
return elapsed
}

View File

@@ -2,7 +2,6 @@ package peer
import (
"context"
"sync"
"testing"
"time"
@@ -49,6 +48,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}
func TestWGWatcher_ReEnable(t *testing.T) {
@@ -60,21 +60,14 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
}()
cancel()
wg.Wait()
// Re-enable with a new context
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()
go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})
@@ -87,4 +80,5 @@ func TestWGWatcher_ReEnable(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -107,10 +106,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if w.agent != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
sessionID, err := NewICESessionID()
@@ -289,8 +286,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
@@ -331,7 +328,13 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
addrString := pair.Remote.Address()
parsed, err := netip.ParseAddr(addrString)
if (err == nil) && (parsed.Is6()) {
addrString = fmt.Sprintf("[%s]", addrString)
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
@@ -383,44 +386,12 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
}
}
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
sessionID := w.SessionID()
stats := agent.GetCandidatePairsStats()
localCandidates, _ := agent.GetLocalCandidates()
remoteCandidates, _ := agent.GetRemoteCandidates()
localMap := make(map[string]ice.Candidate)
for _, c := range localCandidates {
localMap[c.ID()] = c
}
remoteMap := make(map[string]ice.Candidate)
for _, c := range remoteCandidates {
remoteMap[c.ID()] = c
}
for _, stat := range stats {
if stat.State == ice.CandidatePairStateSucceeded {
local, lok := localMap[stat.LocalCandidateID]
remote, rok := remoteMap[stat.RemoteCandidateID]
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}
}
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
w.logSuccessfulPaths(agent)
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to

View File

@@ -30,9 +30,11 @@ type WorkerRelay struct {
relayLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,
@@ -40,6 +42,7 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
config: config,
conn: conn,
relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
}
return r
}
@@ -90,6 +93,14 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
})
}
// EnableWgWatcher enables the WireGuard watcher of this relay worker with the given context and
// registers onWGDisconnected as the callback passed to the watcher.
// NOTE(review): the underlying watcher call appears to run the handshake check loop synchronously,
// so this presumably blocks until the watcher stops — confirm at the call site.
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
}
// DisableWgWatcher disables the WireGuard watcher of this relay worker by delegating to the watcher.
func (w *WorkerRelay) DisableWgWatcher() {
w.wgWatcher.DisableWgWatcher()
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress()
}
@@ -114,6 +125,14 @@ func (w *WorkerRelay) CloseConn() {
}
}
// onWGDisconnected is the callback handed to the WireGuard watcher. It closes the relayed connection
// while holding relayLock and then notifies the owning Conn that the relay is disconnected.
// NOTE(review): assumes w.relayedConn is set when the watcher fires — confirm against the callers.
func (w *WorkerRelay) onWGDisconnected() {
w.relayLock.Lock()
// the Close error is intentionally discarded
_ = w.relayedConn.Close()
w.relayLock.Unlock()
// the Conn is notified only after relayLock has been released
w.conn.onRelayDisconnected()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() {
return false
@@ -129,5 +148,6 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
}
// onRelayClientDisconnected disables the WireGuard watcher and then notifies the owning Conn about
// the relay disconnection in a separate goroutine.
func (w *WorkerRelay) onRelayClientDisconnected() {
w.wgWatcher.DisableWgWatcher()
// the notification runs in its own goroutine, so this handler does not wait for it
go w.conn.onRelayDisconnected()
}

View File

@@ -0,0 +1,138 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
// ProviderConfig holds the identity provider settings used by the PKCE flow
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate the PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID is the IDP application client ID
ClientID string
// ClientSecret is the IDP application client secret
ClientSecret string
// Audience is the audience used for authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain an access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain an authorization code
AuthorizationEndpoint string
// Scope provides the scopes to be included in the token request
Scope string
// RedirectURLs are the redirect URLs that handle the authorization code from the IDP manager
RedirectURLs []string
// UseIDToken indicates if the ID token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication.
// NOTE(review): GetPKCEAuthorizationFlowInfo does not populate this field; presumably callers set it — verify.
LoginHint string
}
// GetPKCEAuthorizationFlowInfo retrieves the PKCE authorization flow configuration from the
// Management service and returns it as a validated PKCEAuthorizationFlow.
//
// privateKey is the peer's WireGuard private key; it is only parsed and handed to the management
// client and is never written to the logs. mgmURL is the Management service URL; TLS is enabled
// when its scheme is "https". clientCert is stored unchanged in ProviderConfig.ClientCertPair.
//
// An error is returned when the key cannot be parsed, mgmURL is nil, the Management service cannot
// be reached or queried, or the received provider configuration fails validation.
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
	// validate our peer's Wireguard PRIVATE key
	myPrivateKey, err := wgtypes.ParseKey(privateKey)
	if err != nil {
		// the key is secret material: log only the parse error, never the key itself
		log.Errorf("failed parsing Wireguard key: [%s]", err.Error())
		return PKCEAuthorizationFlow{}, err
	}

	// fail with an error instead of a nil pointer dereference
	if mgmURL == nil {
		return PKCEAuthorizationFlow{}, fmt.Errorf("management URL is not set")
	}

	mgmTLSEnabled := mgmURL.Scheme == "https"

	log.Debugf("connecting to Management Service %s", mgmURL.String())
	mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
	if err != nil {
		log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
		return PKCEAuthorizationFlow{}, err
	}
	log.Debugf("connected to the Management service %s", mgmURL.String())

	defer func() {
		// use a dedicated variable so the close result never overwrites the function's err
		if closeErr := mgmClient.Close(); closeErr != nil {
			log.Warnf("failed to close the Management service client %v", closeErr)
		}
	}()

	serverKey, err := mgmClient.GetServerPublicKey()
	if err != nil {
		log.Errorf("failed while getting Management Service public key: %v", err)
		return PKCEAuthorizationFlow{}, err
	}

	protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
	if err != nil {
		// a missing flow is a configuration issue on the server side, so it is only a warning
		if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
			log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
			return PKCEAuthorizationFlow{}, err
		}
		log.Errorf("failed to retrieve pkce flow: %v", err)
		return PKCEAuthorizationFlow{}, err
	}

	// fetch the provider config once instead of once per field
	providerConfig := protoPKCEAuthorizationFlow.GetProviderConfig()

	authFlow := PKCEAuthorizationFlow{
		ProviderConfig: PKCEAuthProviderConfig{
			Audience:              providerConfig.GetAudience(),
			ClientID:              providerConfig.GetClientID(),
			ClientSecret:          providerConfig.GetClientSecret(),
			TokenEndpoint:         providerConfig.GetTokenEndpoint(),
			AuthorizationEndpoint: providerConfig.GetAuthorizationEndpoint(),
			Scope:                 providerConfig.GetScope(),
			RedirectURLs:          providerConfig.GetRedirectURLs(),
			UseIDToken:            providerConfig.GetUseIDToken(),
			ClientCertPair:        clientCert,
			DisablePromptLogin:    providerConfig.GetDisablePromptLogin(),
			LoginFlag:             common.LoginFlag(providerConfig.GetLoginFlag()),
		},
	}

	if err := isPKCEProviderConfigValid(authFlow.ProviderConfig); err != nil {
		return PKCEAuthorizationFlow{}, err
	}

	return authFlow, nil
}
// isPKCEProviderConfigValid checks that the provider configuration received from the Management
// service contains every value required to run the PKCE flow.
//
// The fields are checked in a fixed order (client ID, token endpoint, authorization endpoint,
// scope, redirect URLs) and the returned error names the first one that is empty. It returns nil
// when all required values are present.
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
	const errorMSGFormat = "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"

	required := []struct {
		name    string
		missing bool
	}{
		{"Client ID", config.ClientID == ""},
		{"Token Endpoint", config.TokenEndpoint == ""},
		{"Authorization Auth Endpoint", config.AuthorizationEndpoint == ""},
		{"PKCE Auth Scopes", config.Scope == ""},
		// len() rejects both a nil slice and an empty, non-nil slice
		{"PKCE Redirect URLs", len(config.RedirectURLs) == 0},
	}

	for _, field := range required {
		if field.missing {
			return fmt.Errorf(errorMSGFormat, field.name)
		}
	}
	return nil
}

View File

@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultAdminURL)
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err

View File

@@ -17,11 +17,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
defaultLog = slog.LevelInfo
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
)
func hashRosenpassKey(key []byte) string {
hasher := sha256.New()
hasher.Write(key)
@@ -39,7 +34,6 @@ type Manager struct {
server *rp.Server
lock sync.Mutex
port int
wgIface PresharedKeySetter
}
// NewManager creates a new Rosenpass manager
@@ -50,7 +44,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
}
rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
}
@@ -106,7 +100,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
func (m *Manager) generateConfig() (rp.Config, error) {
opts := &slog.HandlerOptions{
Level: getLogLevel(),
Level: slog.LevelDebug,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
cfg := rp.Config{Logger: logger}
@@ -115,13 +109,7 @@ func (m *Manager) generateConfig() (rp.Config, error) {
cfg.SecretKey = m.ssk
cfg.Peers = []rp.PeerConfig{}
m.lock.Lock()
m.rpWgHandler = NewNetbirdHandler()
if m.wgIface != nil {
m.rpWgHandler.SetInterface(m.wgIface)
}
m.lock.Unlock()
m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName)
cfg.Handlers = []rp.Handler{m.rpWgHandler}
@@ -138,26 +126,6 @@ func (m *Manager) generateConfig() (rp.Config, error) {
return cfg, nil
}
func getLogLevel() slog.Level {
level, ok := os.LookupEnv(defaultLogLevelVar)
if !ok {
return defaultLog
}
switch strings.ToLower(level) {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
return defaultLog
}
}
func (m *Manager) OnDisconnected(peerKey string) {
m.lock.Lock()
defer m.lock.Unlock()
@@ -204,20 +172,6 @@ func (m *Manager) Close() error {
return nil
}
// SetInterface sets the WireGuard interface for the rosenpass handler.
// This can be called before or after Run() - the interface will be stored
// and passed to the handler when it's created or updated immediately if
// already running.
func (m *Manager) SetInterface(iface PresharedKeySetter) {
m.lock.Lock()
defer m.lock.Unlock()
m.wgIface = iface
if m.rpWgHandler != nil {
m.rpWgHandler.SetInterface(iface)
}
}
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
m.lock.Lock()
@@ -238,20 +192,6 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [
}
}
// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for the given WireGuard peer.
func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool {
m.lock.Lock()
defer m.lock.Unlock()
peerID, ok := m.rpPeerIDs[wireGuardPubKey]
if !ok || peerID == nil {
return false
}
return m.rpWgHandler.IsPeerInitialized(*peerID)
}
func findRandomAvailableUDPPort() (int, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {

View File

@@ -1,50 +1,46 @@
package rosenpass
import (
"sync"
"fmt"
"log/slog"
rp "cunicu.li/go-rosenpass"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers.
// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface.
type PresharedKeySetter interface {
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}
type wireGuardPeer struct {
Interface string
PublicKey rp.Key
}
type NetbirdHandler struct {
mu sync.Mutex
iface PresharedKeySetter
peers map[rp.PeerID]wireGuardPeer
initializedPeers map[rp.PeerID]bool
ifaceName string
client *wgctrl.Client
peers map[rp.PeerID]wireGuardPeer
presharedKey [32]byte
}
func NewNetbirdHandler() *NetbirdHandler {
return &NetbirdHandler{
peers: map[rp.PeerID]wireGuardPeer{},
initializedPeers: map[rp.PeerID]bool{},
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
hdlr = &NetbirdHandler{
ifaceName: wgIfaceName,
peers: map[rp.PeerID]wireGuardPeer{},
}
}
// SetInterface sets the WireGuard interface for the handler.
// This must be called after the WireGuard interface is created.
func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) {
h.mu.Lock()
defer h.mu.Unlock()
h.iface = iface
if preSharedKey != nil {
hdlr.presharedKey = *preSharedKey
}
if hdlr.client, err = wgctrl.New(); err != nil {
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
}
return hdlr, nil
}
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
h.mu.Lock()
defer h.mu.Unlock()
h.peers[pid] = wireGuardPeer{
Interface: intf,
PublicKey: pk,
@@ -52,61 +48,79 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
}
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.peers, pid)
delete(h.initializedPeers, pid)
}
// IsPeerInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for this peer.
func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool {
h.mu.Lock()
defer h.mu.Unlock()
return h.initializedPeers[pid]
}
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
log.Debug("Handshake complete")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
key, _ := rp.GeneratePresharedKey()
log.Debug("Handshake expired")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
h.mu.Lock()
iface := h.iface
wg, ok := h.peers[pid]
isInitialized := h.initializedPeers[pid]
h.mu.Unlock()
if iface == nil {
log.Warn("rosenpass: interface not set, cannot update preshared key")
return
}
if !ok {
return
}
peerKey := wgtypes.Key(wg.PublicKey).String()
pskKey := wgtypes.Key(psk)
// Use updateOnly=true for later rotations (peer already has Rosenpass PSK)
// Use updateOnly=false for first rotation (peer has original/empty PSK)
if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
device, err := h.client.Device(h.ifaceName)
if err != nil {
log.Errorf("Failed to get WireGuard device: %v", err)
return
}
config := []wgtypes.PeerConfig{
{
UpdateOnly: true,
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
},
}
for _, peer := range device.Peers {
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
config = []wgtypes.PeerConfig{
{
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
Endpoint: peer.Endpoint,
AllowedIPs: peer.AllowedIPs,
},
}
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
Remove: true,
PublicKey: wgtypes.Key(wg.PublicKey),
},
},
})
if err != nil {
slog.Debug("Failed to remove peer")
return
}
}
// Mark peer as isInitialized after the successful first rotation
if !isInitialized {
h.mu.Lock()
if _, exists := h.peers[pid]; exists {
h.initializedPeers[pid] = true
}
h.mu.Unlock()
}
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: config,
}); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
}
}
// publicKeyEmpty reports whether every byte of the given key is zero.
func publicKeyEmpty(key wgtypes.Key) bool {
	// the zero value of the key array has all bytes set to 0, so an equality check covers every byte
	var zeroKey wgtypes.Key
	return key == zeroKey
}

View File

@@ -173,21 +173,12 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
},
)

View File

@@ -4,17 +4,16 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
@@ -25,51 +24,42 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
if flags&unix.RTF_CLONING != 0 {
if flags&syscall.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C")
}
if flags&unix.RTF_WASCLONED != 0 {
if flags&syscall.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W")
}
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -4,18 +4,17 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
@@ -26,46 +25,37 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

Some files were not shown because too many files have changed in this diff Show More