Compare commits

..

22 Commits

Author SHA1 Message Date
aliamerj
e953089f0b fix snyk hopefully 2025-08-29 16:13:00 +03:00
aliamerj
a9c179ee61 fix snyk 2025-08-29 16:01:31 +03:00
aliamerj
55873e38f7 fix code-gen 2025-08-29 15:28:04 +03:00
aliamerj
073990c10f update codegen 2025-08-29 15:00:15 +03:00
aliamerj
6c80e6adee fix 1 2025-08-29 14:23:04 +03:00
aliamerj
95dff6bbb2 update package 2025-08-29 14:19:22 +03:00
aliamerj
0b9c1fd0a3 running go mod tidy 2025-08-29 14:05:22 +03:00
aliamerj
e19cd11461 use RawJson for both parameters and results 2025-08-29 13:56:28 +03:00
aliamerj
2db19c27f2 fix sonar issue 2025-08-29 10:50:55 +03:00
aliamerj
373f978baa get rid of any/interface type in job database 2025-08-29 10:17:44 +03:00
aliamerj
d693c4f5a6 fix error handle in create job 2025-08-29 08:50:25 +03:00
aliamerj
6c0a46dfd5 apply feedback 2 2025-08-28 00:19:53 +03:00
aliamerj
9043938233 clean switch case 2025-08-27 13:00:43 +03:00
aliamerj
556c4c7777 fix api object 2025-08-27 12:46:20 +03:00
aliamerj
d1aec58108 fix lint 2025-08-26 13:09:54 +03:00
aliamerj
667dfdfcc3 change api and apply new schema 2025-08-26 13:00:28 +03:00
aliamerj
1fe6295d44 fix typo 2025-08-25 15:44:06 +03:00
aliamerj
cc1338f92d apply feedbacks 1 2025-08-22 20:57:50 +03:00
aliamerj
b7f0088fe3 fix MarkPendingJobsAsFailed 2025-08-21 15:04:11 +03:00
aliamerj
78bc18cb9e clean up 2025-08-20 22:19:37 +03:00
aliamerj
0befefe085 fix lint 2025-08-20 20:42:04 +03:00
aliamerj
f4757165c1 implement remote debug api 2025-08-20 20:15:46 +03:00
420 changed files with 5861 additions and 17306 deletions

View File

@@ -16,29 +16,19 @@ jobs:
steps:
- name: Read PR body
id: body
shell: bash
run: |
set -euo pipefail
BODY_B64=$(jq -r '.pull_request.body // "" | @base64' "$GITHUB_EVENT_PATH")
{
echo "body_b64=$BODY_B64"
} >> "$GITHUB_OUTPUT"
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH")
echo "body<<EOF" >> $GITHUB_OUTPUT
echo "$BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Validate checkbox selection
id: validate
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: |
set -euo pipefail
if ! body="$(printf '%s' "$BODY_B64" | base64 -d)"; then
echo "::error::Failed to decode PR body from base64. Data may be corrupted or missing."
exit 1
fi
added_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*I added/updated documentation' | wc -l | tr -d '[:space:]' || true)
noneed_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*Documentation is \*\*not needed\*\*' | wc -l | tr -d '[:space:]' || true)
body='${{ steps.body.outputs.body }}'
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
@@ -51,35 +41,30 @@ jobs:
fi
if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> "$GITHUB_OUTPUT"
echo "mode=added" >> $GITHUB_OUTPUT
else
echo "mode=noneed" >> "$GITHUB_OUTPUT"
echo "mode=noneed" >> $GITHUB_OUTPUT
fi
- name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added'
id: extract
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: |
set -euo pipefail
body="$(printf '%s' "$BODY_B64" | base64 -d)"
body='${{ steps.body.outputs.body }}'
# Strictly require HTTPS and that it's a PR in netbirdio/docs
# e.g., https://github.com/netbirdio/docs/pull/1234
url="$(printf '%s' "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)"
# Examples accepted:
# https://github.com/netbirdio/docs/pull/1234
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
if [ -z "${url:-}" ]; then
if [ -z "$url" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1
fi
pr_number="$(printf '%s' "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')"
{
echo "url=$url"
echo "pr_number=$pr_number"
} >> "$GITHUB_OUTPUT"
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')
echo "url=$url" >> $GITHUB_OUTPUT
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
raceFlag: ""
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -382,32 +382,6 @@ jobs:
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -454,10 +428,9 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/...
api_benchmark:
name: "Management / Benchmark (API)"
@@ -548,7 +521,7 @@ jobs:
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/server/http/...
-timeout 20m ./management/... ./shared/management/...
api_integration_test:
name: "Management / Integration"
@@ -598,4 +571,4 @@ jobs:
CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/server/http/...
-timeout 20m ./management/... ./shared/management/...

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
golangci:
strategy:

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.23"
SIGN_PIPE_VER: "v0.0.22"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -1,67 +0,0 @@
name: Wasm
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:
name: "JS / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size
run: |
echo "Wasm build size:"
ls -lh netbird.wasm
SIZE=$(stat -c%s netbird.wasm)
SIZE_MB=$((SIZE / 1024 / 1024))
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

0
.gitmodules vendored
View File

View File

@@ -2,18 +2,6 @@ version: 2
project_name: netbird
builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird
dir: client
binary: netbird
@@ -127,11 +115,6 @@ archives:
- builds:
- netbird
- netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
nfpms:
- maintainer: Netbird <dev@netbird.io>

View File

@@ -1,4 +1,3 @@
<div align="center">
<br/>
<br/>
@@ -53,7 +52,7 @@
### Open Source Network Security in a Single Platform
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.2
FROM alpine:3.22.0
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -4,7 +4,6 @@ package android
import (
"context"
"os"
"slices"
"sync"
@@ -19,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/util/net"
)
// ConnectionListener export internal Listener for mobile
@@ -84,8 +83,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -120,8 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -252,14 +249,3 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -1,32 +0,0 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -33,7 +33,6 @@ type ErrListener interface {
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -182,11 +181,6 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
@@ -303,18 +303,12 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
)
}
return statusOutputString

View File

@@ -1,8 +0,0 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
@@ -228,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
}
// update host's static platform and system information
system.UpdateStaticInfoAsync()
system.UpdateStaticInfo()
configFilePath, err := activeProf.FilePath()
if err != nil {
@@ -357,21 +356,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := openBrowser(verificationURIComplete); err != nil {
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -39,7 +39,6 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection"
mtuFlag = "mtu"
)
var (
@@ -73,7 +72,6 @@ var (
anonymizeFlag bool
dnsRouteInterval time.Duration
lazyConnEnabled bool
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
@@ -231,7 +229,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
defer cancel()
return grpc.DialContext(

View File

@@ -54,7 +54,6 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
cmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfoAsync()
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, false)
resp, err := getStatus(ctx)
if err != nil {
return err
}
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}

View File

@@ -9,28 +9,29 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func startTestingServices(t *testing.T) string {
@@ -89,20 +90,15 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT().
@@ -116,7 +112,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -63,7 +63,6 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
upCmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
@@ -230,9 +229,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
@@ -360,11 +357,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.WireguardPort = &p
}
if cmd.Flag(mtuFlag).Changed {
m := int64(mtu)
req.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &networkMonitor
}
@@ -444,13 +436,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.WireguardPort = &p
}
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
ic.MTU = &mtu
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
@@ -548,14 +533,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.WireguardPort = &wp
}
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
m := int64(mtu)
loginRequest.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor
}

View File

@@ -23,29 +23,23 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance.
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *profilemanager.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
jwtToken string
connect *internal.ConnectClient
}
// Options configures a new Client.
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// PrivateKey is used for direct private key authentication
PrivateKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
@@ -64,35 +58,8 @@ type Options struct {
DisableClientRoutes bool
}
// validateCredentials checks that exactly one credential type is provided
func (opts *Options) validateCredentials() error {
credentialsProvided := 0
if opts.SetupKey != "" {
credentialsProvided++
}
if opts.JWTToken != "" {
credentialsProvided++
}
if opts.PrivateKey != "" {
credentialsProvided++
}
if credentialsProvided == 0 {
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
}
if credentialsProvided > 1 {
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
}
return nil
}
// New creates a new netbird embedded client.
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if err := opts.validateCredentials(); err != nil {
return nil, err
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -140,14 +107,9 @@ func New(opts Options) (*Client, error) {
return nil, fmt.Errorf("create config: %w", err)
}
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
}, nil
}
@@ -164,7 +126,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
@@ -173,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{})
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
@@ -225,16 +187,6 @@ func (c *Client) Stop(ctx context.Context) error {
}
}
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -259,7 +211,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network.
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
@@ -280,7 +232,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network.
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -400,6 +400,7 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"

View File

@@ -260,22 +260,6 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
// constants needed to manage and create iptable rules
@@ -880,54 +880,6 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
func isIptablesSupported() bool {

View File

@@ -151,20 +151,14 @@ type Manager interface {
DisableRouting() error
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes the outbound DNAT rule.
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (

View File

@@ -376,22 +376,6 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -1350,103 +1350,6 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -22,8 +22,6 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
}
// these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -171,30 +171,28 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
return key, true
}
return key, 0, false
return key, false
}
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
return origPort
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 {
return
}
@@ -212,13 +210,8 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.logger.Trace2("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
@@ -456,21 +449,6 @@ func (t *TCPTracker) cleanup() {
}
}
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{
SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,23 +58,20 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
if exists {
return origPort
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -89,15 +86,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
return key, true
}
return key, 0, false
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
@@ -112,7 +109,6 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort,
DestPort: dstPort,
}
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
@@ -120,11 +116,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.logger.Trace2("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}

View File

@@ -109,10 +109,6 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
}
// decoder for packages
@@ -126,8 +122,6 @@ type decoder struct {
icmp6 layers.ICMPv6
decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
}
// Create userspace firewall manager constructor
@@ -202,7 +196,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
}
m.routingEnabled.Store(false)
@@ -637,7 +630,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
return false
@@ -681,26 +674,14 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
}
@@ -710,15 +691,13 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
}
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
@@ -780,20 +759,10 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)

View File

@@ -50,8 +50,6 @@ type logMessage struct {
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -96,6 +94,7 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -186,15 +185,6 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
}
}
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
@@ -249,16 +239,6 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
}
}
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -280,12 +260,6 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
@@ -309,10 +283,6 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
}
*buf = append(*buf, formatted...)
@@ -420,4 +390,4 @@ func (l *Logger) Stop(ctx context.Context) error {
case <-done:
return nil
}
}
}

View File

@@ -5,9 +5,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -15,21 +13,6 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
@@ -69,7 +52,6 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum)
}
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
@@ -107,21 +89,11 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum)
}
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
@@ -129,13 +101,11 @@ func newBiDNATMap() *biDNATMap {
}
}
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
@@ -143,25 +113,19 @@ func (b *biDNATMap) delete(original netip.Addr) {
}
}
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() {
return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -171,6 +135,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
@@ -186,7 +151,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
@@ -204,7 +169,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil
}
// getDNATTranslation returns the translated address if a mapping exists.
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
@@ -216,7 +181,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic.
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
@@ -228,12 +193,16 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets.
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -241,8 +210,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err)
return false
}
@@ -250,12 +219,16 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic.
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -263,8 +236,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err)
return false
}
@@ -272,21 +245,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if !newIP.Is4() {
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newIPBytes := newIP.As4()
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -296,9 +269,44 @@ func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Add
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
@@ -307,7 +315,6 @@ func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Add
return nil
}
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
@@ -320,7 +327,6 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
@@ -338,7 +344,6 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
@@ -351,7 +356,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624.
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
@@ -386,7 +391,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum)
}
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
@@ -394,184 +399,10 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes outbound DNAT rule.
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -414,127 +414,3 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
}
})
}
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -144,111 +143,3 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,33 +16,25 @@ type PacketStage int
const (
StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageInboundPortDNAT: "Inbound Port DNAT",
StageInbound1to1NAT: "Inbound 1:1 NAT",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
}[s]
}
@@ -269,10 +261,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
@@ -412,16 +400,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
// will create or update the connection state
dropped := m.filterOutbound(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -430,199 +409,3 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
}
return trace
}
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -104,8 +104,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -128,8 +126,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -157,8 +153,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -185,8 +179,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -212,8 +204,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -238,8 +228,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -258,8 +246,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -278,8 +264,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageCompleted,
@@ -303,8 +287,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageCompleted,
},
@@ -319,8 +301,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: true,
@@ -339,8 +319,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -362,8 +340,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -386,8 +362,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -408,8 +382,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -434,8 +406,6 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting,
StagePeerACL,
StageCompleted,

View File

@@ -1,44 +0,0 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

View File

@@ -1,13 +0,0 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
}

View File

@@ -3,7 +3,7 @@ package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -1,17 +1,5 @@
package bind
import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -1,7 +0,0 @@
package bind
import "fmt"
var (
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
)

View File

@@ -1,9 +1,6 @@
//go:build !js
package bind
import (
"context"
"encoding/binary"
"fmt"
"net"
@@ -11,18 +8,22 @@ import (
"runtime"
"sync"
"github.com/pion/stun/v3"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
@@ -40,38 +41,35 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn udpmux.FilterFn
address wgaddr.Address
mtu uint16
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
recvChan chan recvMessage
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
activityRecorder *ActivityRecorder
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
activityRecorder *ActivityRecorder
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn),
recvChan: make(chan recvMessage, 1),
closedChan: make(chan struct{}),
closed: true,
address: address,
activityRecorder: NewActivityRecorder(),
}
@@ -111,7 +109,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
@@ -134,16 +132,6 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
delete(b.endpoints, fakeIP)
}
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-b.closedChan:
return
case <-ctx.Done():
return
case b.recvChan <- recvMessage{ep, buf}:
}
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
@@ -164,13 +152,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -276,7 +263,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.recvChan:
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}

View File

@@ -1,6 +0,0 @@
package bind
type recvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

View File

@@ -1,125 +0,0 @@
package bind
import (
"context"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
)
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
// Do not limit to build only js, because we want to be able to run tests
type RelayBindJS struct {
*conn.StdNetBind
recvChan chan recvMessage
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
activityRecorder *ActivityRecorder
ctx context.Context
cancel context.CancelFunc
}
func NewRelayBindJS() *RelayBindJS {
return &RelayBindJS{
recvChan: make(chan recvMessage, 100),
endpoints: make(map[netip.Addr]net.Conn),
activityRecorder: NewActivityRecorder(),
}
}
// Open creates a receive function for handling relay packets in WASM.
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.ctx, s.cancel = context.WithCancel(context.Background())
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
select {
case <-s.ctx.Done():
return 0, net.ErrClosed
case msg, ok := <-s.recvChan:
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
func (s *RelayBindJS) Close() error {
if s.cancel == nil {
return nil
}
log.Debugf("close RelayBindJS")
s.cancel()
return nil
}
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-s.ctx.Done():
return
case <-ctx.Done():
return
case s.recvChan <- recvMessage{ep, buf}:
}
}
// Send forwards packets through the relay connection for WASM.
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
return err
}
}
return nil
}
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
}
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, ErrUDPMUXNotSupported
}
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}

View File

@@ -1,4 +1,4 @@
package udpmux
package bind
import (
"fmt"
@@ -8,9 +8,9 @@ import (
"strings"
"sync"
"github.com/pion/ice/v4"
"github.com/pion/ice/v3"
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192
// SingleSocketUDPMux is an implementation of the interface
type SingleSocketUDPMux struct {
params Params
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
@@ -32,9 +32,6 @@ type SingleSocketUDPMux struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
@@ -49,8 +46,8 @@ type SingleSocketUDPMux struct {
const maxAddrSize = 512
// Params are parameters for UDPMux.
type Params struct {
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
@@ -150,19 +147,18 @@ func isZeros(ip net.IP) bool {
return true
}
// NewSingleSocketUDPMux creates an implementation of UDPMux
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = getLogger()
}
mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
@@ -175,15 +171,15 @@ func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
return mux
}
func (m *SingleSocketUDPMux) updateLocalAddresses() {
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
@@ -220,13 +216,13 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
m.mu.Unlock()
}
// LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
@@ -240,7 +236,7 @@ func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -264,14 +260,12 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
return conn, nil
}
c := m.createMuxedConn(ufrag, candidateID)
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
m.candidateConnMap[candidateID] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
@@ -282,7 +276,7 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
@@ -290,12 +284,10 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
m.mu.Unlock()
@@ -322,7 +314,7 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
}
// IsClosed returns true if the mux had been closed
func (m *SingleSocketUDPMux) IsClosed() bool {
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
@@ -332,7 +324,7 @@ func (m *SingleSocketUDPMux) IsClosed() bool {
}
// Close the mux, no further connections could be created
func (m *SingleSocketUDPMux) Close() error {
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
@@ -355,11 +347,11 @@ func (m *SingleSocketUDPMux) Close() error {
return err
}
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
@@ -376,109 +368,81 @@ func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr str
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
CandidateID: candidateID,
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// Try to route to specific candidate connection first
if conn := m.findCandidateConnection(msg); conn != nil {
return conn.writePacket(msg.Raw, remoteAddr)
}
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
// Add connections from address map
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// Forward to all found connections
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
// findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
@@ -487,13 +451,6 @@ func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedCo
return
}
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct {
buf []byte
}

View File

@@ -1,12 +1,12 @@
//go:build !ios
package udpmux
package bind
import (
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package udpmux
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -15,10 +15,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +28,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*SingleSocketUDPMux
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -45,7 +44,6 @@ type UniversalUDPMuxParams struct {
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
MTU uint16
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -72,12 +70,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
udpMuxParams := Params{
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
@@ -86,7 +84,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// just ignore other packets printing an warning message.
// It is a blocking method, consider running in a go routine.
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
buf := make([]byte, m.params.MTU+bufsize.WGBufferOverhead)
buf := make([]byte, 1500)
for {
select {
case <-ctx.Done():
@@ -211,8 +209,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +231,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
}
return nil
}
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -1,4 +1,4 @@
package udpmux
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,12 +16,11 @@ import (
)
type udpMuxedConnParams struct {
Mux *SingleSocketUDPMux
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
CandidateID string
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -120,10 +119,6 @@ func (c *udpMuxedConn) Close() error {
return err
}
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:

View File

@@ -1,9 +0,0 @@
package bufsize
const (
// WGBufferOverhead represents the additional buffer space needed beyond MTU
// for WireGuard packet encapsulation (WG header + UDP + IP + safety margin)
// Original hardcoded buffers were 1500, default MTU is 1280, so overhead = 220
// TODO: Calculate this properly based on actual protocol overhead instead of using hardcoded difference
WGBufferOverhead = 220
)

View File

@@ -73,44 +73,6 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil
}
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd || js || wasip1
//go:build linux || windows || freebsd
package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows && !js
//go:build !windows
package configurer

View File

@@ -1,23 +0,0 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -106,67 +106,6 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil
}
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -455,13 +394,6 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil
}
@@ -470,7 +402,7 @@ func toBytes(s string) (int64, error) {
}
func getFwmark() int {
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark
}
return 0

View File

@@ -7,21 +7,19 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -22,7 +21,7 @@ type WGTunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
mtu int
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
@@ -30,11 +29,11 @@ type WGTunDevice struct {
name string
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -59,7 +58,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
return nil, err
@@ -89,7 +88,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
}
return t.configurer, nil
}
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -138,10 +137,6 @@ func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *WGTunDevice) MTU() uint16 {
return t.mtu
}
func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
@@ -150,11 +145,6 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *WGTunDevice) GetICEBind() EndpointManager {
return t.iceBind
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -22,16 +21,16 @@ type TunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
mtu int
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -43,7 +42,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, int(t.mtu))
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
@@ -72,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -112,10 +111,6 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string {
return t.name
}
@@ -154,8 +149,3 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -23,23 +22,21 @@ type TunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
tunFd int
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunFd: tunFd,
}
@@ -84,7 +81,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -128,10 +125,6 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement
return nil
@@ -144,8 +137,3 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -24,19 +24,19 @@ type TunKernelDevice struct {
address wgaddr.Address
wgPort int
key string
mtu uint16
mtu int
ctx context.Context
ctxCancel context.CancelFunc
transportNet transport.Net
link *wgLink
udpMuxConn net.PacketConn
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
filterFn udpmux.FilterFn
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -66,7 +66,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
// TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
if err := link.setMTU(int(t.mtu)); err != nil {
if err := link.setMTU(t.mtu); err != nil {
return nil, fmt.Errorf("set mtu: %w", err)
}
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil
}
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -96,19 +96,23 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err
}
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter(), t.mtu)
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter())
if err != nil {
return nil, err
}
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux
@@ -154,10 +158,6 @@ func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunKernelDevice) MTU() uint16 {
return t.mtu
}
func (t *TunKernelDevice) DeviceName() string {
return t.name
}
@@ -179,8 +179,3 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns nil for kernel mode devices
func (t *TunKernelDevice) GetICEBind() EndpointManager {
return nil
}

View File

@@ -1,48 +1,41 @@
//go:build !android
// +build !android
package device
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
EndpointManager
}
type TunNetstackDevice struct {
name string
address wgaddr.Address
port int
key string
mtu uint16
mtu int
listenAddress string
bind Bind
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
}
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -50,11 +43,11 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
key: key,
mtu: mtu,
listenAddress: listenAddress,
bind: bind,
iceBind: iceBind,
}
}
func (t *TunNetstackDevice) create() (WGConfigurer, error) {
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
@@ -64,7 +57,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
}
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil {
@@ -75,11 +68,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.device = device.NewDevice(
t.filteredDevice,
t.bind,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
@@ -90,7 +83,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -100,15 +93,11 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err
}
udpMux, err := t.bind.GetICEMux()
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
if udpMux != nil {
t.udpMux = udpMux
}
t.udpMux = udpMux
log.Debugf("netstack device is ready to use")
return udpMux, nil
}
@@ -136,10 +125,6 @@ func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunNetstackDevice) MTU() uint16 {
return t.mtu
}
func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
@@ -156,8 +141,3 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}
// GetICEBind returns the bind instance
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
return t.bind
}

View File

@@ -1,7 +0,0 @@
//go:build android
package device
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}

View File

@@ -1,7 +0,0 @@
//go:build !android
package device
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
return t.create()
}

View File

@@ -1,27 +0,0 @@
package device
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestNewNetstackDevice(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
relayBind := bind.NewRelayBindJS()
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
cfgr, err := nsTun.Create()
if err != nil {
t.Fatalf("failed to create netstack device: %v", err)
}
if cfgr == nil {
t.Fatal("expected non-nil configurer")
}
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -21,16 +20,16 @@ type USPDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
mtu int
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
@@ -45,9 +44,9 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, int(t.mtu), err)
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
@@ -75,7 +74,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -119,10 +118,6 @@ func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *USPDevice) MTU() uint16 {
return t.mtu
}
func (t *USPDevice) DeviceName() string {
return t.name
}
@@ -146,8 +141,3 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -24,17 +23,17 @@ type TunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
mtu int
iceBind *bind.ICEBind
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -60,7 +59,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, err
}
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, int(t.mtu))
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
@@ -105,7 +104,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -145,10 +144,6 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string {
return t.name
}
@@ -185,8 +180,3 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -1,13 +0,0 @@
package device
import (
"net"
"net/netip"
)
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,5 +21,4 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
}

View File

@@ -5,21 +5,19 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -26,8 +26,6 @@ import (
const (
DefaultMTU = 1280
MinMTU = 576
MaxMTU = 8192
DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault
)
@@ -37,17 +35,6 @@ var (
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
// ValidateMTU validates that MTU is within acceptable range
func ValidateMTU(mtu uint16) error {
if mtu < MinMTU {
return fmt.Errorf("MTU %d below minimum (%d bytes)", mtu, MinMTU)
}
if mtu > MaxMTU {
return fmt.Errorf("MTU %d exceeds maximum supported size (%d bytes)", mtu, MaxMTU)
}
return nil
}
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -58,10 +45,10 @@ type WGIFaceOpts struct {
Address string
WGPort int
WGPrivKey string
MTU uint16
MTU int
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn udpmux.FilterFn
FilterFn bind.FilterFn
DisableDNS bool
}
@@ -80,17 +67,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun == nil {
return nil
}
return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind
@@ -106,10 +82,6 @@ func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress()
}
func (w *WGIface) MTU() uint16 {
return w.tun.MTU()
}
// ToInterface returns the net.Interface for the Wireguard interface
func (r *WGIface) ToInterface() *net.Interface {
name := r.tun.DeviceName()
@@ -125,7 +97,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -159,17 +131,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()

View File

@@ -1,6 +0,0 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -3,7 +3,6 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -15,21 +14,12 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -17,7 +17,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice
if netstack.IsEnabled() {
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,41 +0,0 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -16,12 +16,12 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,27 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}
return wgIface, nil
}

View File

@@ -1,4 +1,4 @@
//go:build linux && !android
//go:build (linux && !android) || freebsd
package iface
@@ -22,23 +22,23 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}

View File

@@ -14,7 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice
if netstack.IsEnabled() {
@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil

View File

@@ -1,5 +1,3 @@
//go:build !js
package netstack
import (

View File

@@ -1,12 +0,0 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -1,64 +0,0 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,7 +0,0 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -12,42 +12,31 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type Bind interface {
SetEndpoint(addr netip.Addr, conn net.Conn)
RemoveEndpoint(addr netip.Addr)
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
}
type ProxyBind struct {
bind Bind
Bind *bind.ICEBind
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
fakeNetIP *netip.AddrPort
wgBindEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
paused bool
pausedCond *sync.Cond
isStarted bool
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
mtu uint16
}
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
bind: bind,
Bind: bind,
closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
}
return p
@@ -56,25 +45,25 @@ func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr)
if err != nil {
return err
}
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return nil
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -86,21 +75,17 @@ func (p *ProxyBind) Work() {
return
}
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedCond.L.Lock()
p.pausedMu.Lock()
p.paused = false
p.wgCurrentUsed = p.wgRelayedEndpoint
p.pausedMu.Unlock()
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) Pause() {
@@ -108,25 +93,9 @@ func (p *ProxyBind) Pause() {
return
}
p.pausedCond.L.Lock()
p.pausedMu.Lock()
p.paused = true
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
p.pausedMu.Unlock()
}
func (p *ProxyBind) CloseConn() error {
@@ -137,10 +106,6 @@ func (p *ProxyBind) CloseConn() error {
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock()
defer p.closeMu.Unlock()
@@ -154,12 +119,7 @@ func (p *ProxyBind) close() error {
p.cancel()
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
@@ -175,7 +135,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}()
for {
buf := make([]byte, p.mtu)
buf := make([]byte, 1500)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
@@ -186,13 +146,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return
}
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
p.pausedCond.L.Unlock()
msg := bind.RecvMessage{
Endpoint: p.wgBindEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
}
}

View File

@@ -6,7 +6,9 @@ import (
"context"
"fmt"
"net"
"os"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -15,25 +17,18 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
mtu uint16
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
@@ -48,11 +43,10 @@ type WGEBPFProxy struct {
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
mtu: mtu,
ebpfManager: ebpf.GetEbpfManagerInstance(),
turnConnStore: make(map[uint16]net.Conn),
}
@@ -67,7 +61,7 @@ func (p *WGEBPFProxy) Listen() error {
return err
}
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
p.rawConn, err = p.prepareSenderRawSocket()
if err != nil {
return err
}
@@ -144,7 +138,7 @@ func (p *WGEBPFProxy) Free() error {
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil {
@@ -217,17 +211,57 @@ generatePort:
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
DstIP: localhost,
SrcIP: localhost,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
SrcPort: layers.UDPPort(port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
@@ -242,7 +276,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil

View File

@@ -7,7 +7,7 @@ import (
)
func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(1, 1280)
wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -27,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(1, 1280)
wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -43,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(1, 1280)
wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)

View File

@@ -12,48 +12,46 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
wgEndpointAddr *net.UDPAddr
paused bool
pausedCond *sync.Cond
isStarted bool
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgRelayedEndpointAddr
return p.wgEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -65,18 +63,14 @@ func (p *ProxyWrapper) Work() {
return
}
p.pausedCond.L.Lock()
p.pausedMu.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) Pause() {
@@ -85,59 +79,45 @@ func (p *ProxyWrapper) Pause() {
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedCond.L.Lock()
p.pausedMu.Lock()
p.paused = true
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.pausedMu.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
p.cancel()
e.cancel()
p.closeListener.SetCloseListener(nil)
e.closeListener.SetCloseListener(nil)
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("close remote conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock()
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
@@ -156,7 +136,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}

View File

@@ -11,18 +11,16 @@ import (
type KernelFactory struct {
wgPort int
mtu uint16
ebpfProxy *ebpf.WGEBPFProxy
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
func NewKernelFactory(wgPort int) *KernelFactory {
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
@@ -35,10 +33,11 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
func (w *KernelFactory) Free() error {

View File

@@ -0,0 +1,29 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -3,25 +3,24 @@ package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind proxyBind.Bind
mtu uint16
bind *bind.ICEBind
}
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: bind,
mtu: mtu,
bind: iceBind,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu)
return proxyBind.NewProxyBind(w.bind)
}
func (w *USPFactory) Free() error {

View File

@@ -11,11 +11,6 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
//and rewrite the src address to the endpoint address.
//With this logic can avoid the package loss from relayed connections.
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
}

View File

@@ -3,82 +3,54 @@
package wgproxy
import (
"fmt"
"net"
"context"
"os"
"testing"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
func seedProxies() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
pl = append(pl, pUDP)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,39 +0,0 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux
package wgproxy
import (
@@ -5,9 +7,12 @@ import (
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util"
)
@@ -17,14 +22,6 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct {
closeChan chan struct{}
closed bool
@@ -81,21 +78,41 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests, err := seedProxyForProxyCloseByRemoteConn()
if err != nil {
t.Fatalf("error: %v", err)
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830),
},
}
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = relayedConn.Close()
}()
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
@@ -107,104 +124,3 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
})
}
}
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

Some files were not shown because too many files have changed in this diff Show More