Compare commits

..

3 Commits

Author SHA1 Message Date
Pedro Costa
43176fb96c Merge branch 'main' into feature/buf-cli
# Conflicts:
#	client/proto/daemon.pb.go
#	management/proto/management.pb.go
2025-05-12 10:34:14 +01:00
Pedro Costa
d1e4bb0fa3 Merge branch 'main' into feature/buf-cli
# Conflicts:
#	management/proto/management.pb.go
2025-04-17 11:53:47 +01:00
Pedro Costa
12b0c13511 [misc] buf cli proto lint, format and gen 2025-04-09 20:39:32 +01:00
465 changed files with 11245 additions and 30874 deletions

View File

@@ -9,7 +9,7 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@v0.18.1
&& go install -v golang.org/x/tools/gopls@latest
WORKDIR /app

View File

@@ -1,3 +0,0 @@
*
!client/netbird-entrypoint.sh
!netbird

View File

@@ -37,21 +37,16 @@ If yes, which one?
**Debug output**
To help us resolve the problem, please attach the following anonymized status output
To help us resolve the problem, please attach the following debug output
netbird status -dA
Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots**
@@ -62,10 +57,8 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -13,5 +13,3 @@
- [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

View File

@@ -16,6 +16,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1.2.1
- uses: git-town/action@v1
with:
skip-single-stacks: true
skip-single-stacks: true

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -24,8 +24,8 @@ jobs:
id: filter
with:
filters: |
management:
- 'management/**'
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
@@ -148,7 +148,7 @@ jobs:
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -181,7 +181,6 @@ jobs:
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
CONTAINER: "true"
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
@@ -199,7 +198,6 @@ jobs:
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
@@ -213,11 +211,7 @@ jobs:
strategy:
fail-fast: false
matrix:
include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: ""
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -229,10 +223,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -257,9 +247,9 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
go test \
-exec 'sudo' \
-timeout 10m ./relay/...
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
@@ -279,10 +269,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -21,6 +21,7 @@ jobs:
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
only_warn: 1
golangci:
strategy:
fail-fast: false

View File

@@ -43,7 +43,7 @@ jobs:
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.21"
SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -65,13 +65,6 @@ jobs:
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -231,17 +224,3 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -134,7 +134,6 @@ jobs:
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
run: |
set -x
@@ -173,15 +172,13 @@ jobs:
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
- name: Install modules
run: go mod tidy

1
.gitignore vendored
View File

@@ -30,4 +30,3 @@ infrastructure_files/setup-*.env
.vscode
.DS_Store
vendor/
/netbird

View File

@@ -149,119 +149,97 @@ nfpms:
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
@@ -273,11 +251,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
@@ -289,11 +266,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
@@ -306,11 +282,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
ids:
- netbird-signal
goarch: amd64
@@ -322,11 +297,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids:
- netbird-signal
goarch: arm64
@@ -338,11 +312,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids:
- netbird-signal
goarch: arm
@@ -355,11 +328,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids:
- netbird-mgmt
goarch: amd64
@@ -371,11 +343,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids:
- netbird-mgmt
goarch: arm64
@@ -387,11 +358,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids:
- netbird-mgmt
goarch: arm
@@ -404,11 +374,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids:
- netbird-mgmt
goarch: amd64
@@ -420,11 +389,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids:
- netbird-mgmt
goarch: arm64
@@ -436,12 +404,11 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids:
- netbird-mgmt
goarch: arm
@@ -454,11 +421,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
@@ -470,11 +436,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
@@ -486,11 +451,10 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
@@ -503,7 +467,7 @@ dockers:
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
@@ -582,84 +546,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -12,11 +12,8 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://docs.netbird.io/slack-url">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
@@ -32,13 +29,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
<br/>
</strong>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
</a>
</p>
@@ -50,9 +47,10 @@
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open Source Network Security in a Single Platform
### Open-Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

7
buf.gen.yaml Normal file
View File

@@ -0,0 +1,7 @@
# For details on buf.gen.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-gen-yaml/
version: v2
plugins:
- remote: buf.build/protocolbuffers/go:v1.35.1
out: .
- remote: buf.build/grpc/go:v1.5.1
out: .

10
buf.yaml Normal file
View File

@@ -0,0 +1,10 @@
# For details on buf.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-yaml
version: v2
modules:
- path: proto
lint:
use:
- BASIC
breaking:
use:
- FILE

View File

@@ -1,27 +1,6 @@
# build & run locally with:
# cd "$(git rev-parse --show-toplevel)"
# CGO_ENABLED=0 go build -o netbird ./client
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
ca-certificates \
ip6tables \
iproute2 \
iptables
ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,33 +1,17 @@
# build & run locally with:
# cd "$(git rev-parse --show-toplevel)"
# CGO_ENABLED=0 go build -o netbird ./client
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.21.0
FROM alpine:3.22.0
COPY netbird /usr/local/bin/netbird
RUN apk add --no-cache \
bash \
ca-certificates \
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird
WORKDIR /var/lib/netbird
USER netbird:netbird
ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_USE_NETSTACK_MODE="true" \
NB_ENABLE_NETSTACK_LOCAL_FORWARDING="true" \
NB_CONFIG="/var/lib/netbird/config.json" \
NB_STATE_DIR="/var/lib/netbird" \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
@@ -60,14 +59,10 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
@@ -83,7 +78,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
@@ -111,14 +106,14 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
@@ -137,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// Stop the internal client and free the resources
@@ -179,55 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos}
}
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()

View File

@@ -1,26 +0,0 @@
//go:build android
package android
import (
"fmt"
_ "unsafe"
)
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
// In Android version 11 and earlier, pidfd-related system calls
// are not allowed by the seccomp policy, which causes crashes due
// to SIGSYS signals.
//go:linkname checkPidfdOnce os.checkPidfdOnce
var checkPidfdOnce func() error
func execWorkaround(androidSDKVersion int) {
if androidSDKVersion > 30 { // above Android 11
return
}
checkPidfdOnce = func() error {
return fmt.Errorf("unsupported Android version")
}
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -38,17 +37,17 @@ type URLOpener interface {
// Auth can register or login new client
type Auth struct {
ctx context.Context
config *profilemanager.Config
config *internal.Config
cfgPath string
}
// NewAuth instantiate Auth struct and validate the management URL
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
inputCfg := profilemanager.ConfigInput{
inputCfg := internal.ConfigInput{
ManagementURL: mgmURL,
}
cfg, err := profilemanager.CreateInMemoryConfig(inputCfg)
cfg, err := internal.CreateInMemoryConfig(inputCfg)
if err != nil {
return nil, err
}
@@ -61,7 +60,7 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
}
// NewAuthWithConfig instantiate Auth based on existing config
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
return &Auth{
ctx: ctx,
config: config,
@@ -111,7 +110,7 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
err = internal.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -143,7 +142,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
return internal.WriteOutConfig(a.cfgPath, a.config)
}
// Login try register the client on the server

View File

@@ -1,27 +0,0 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,23 +7,30 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum
}
// PeerInfoArray is a wrapper of []PeerInfo
// PeerInfoCollection made for Java layer to get non default types as collection
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct {
items []PeerInfo
}
// Add new PeerInfo to the collection
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
array.items = append(array.items, s)
return array
}
// Get return an element of the collection
func (array *PeerInfoArray) Get(i int) *PeerInfo {
func (array PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i]
}
// Size return with the size of the collection
func (array *PeerInfoArray) Size() int {
func (array PeerInfoArray) Size() int {
return len(array.items)
}

View File

@@ -1,226 +1,78 @@
package android
import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal"
)
// Preferences exports a subset of the internal config for gomobile
// Preferences export a subset of the internal config for gomobile
type Preferences struct {
configInput profilemanager.ConfigInput
configInput internal.ConfigInput
}
// NewPreferences creates a new Preferences instance
// NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences {
ci := profilemanager.ConfigInput{
ci := internal.ConfigInput{
ConfigPath: configPath,
}
return &Preferences{ci}
}
// GetManagementURL reads URL from config file
// GetManagementURL read url from config file
func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
return cfg.ManagementURL.String(), err
}
// SetManagementURL stores the given URL and waits for commit
// SetManagementURL store the given url and wait for commit
func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url
}
// GetAdminURL reads URL from config file
// GetAdminURL read url from config file
func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
return cfg.AdminURL.String(), err
}
// SetAdminURL stores the given URL and waits for commit
// SetAdminURL store the given url and wait for commit
func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url
}
// GetPreSharedKey reads pre-shared key from config file
// GetPreSharedKey read preshared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
return cfg.PreSharedKey, err
}
// SetPreSharedKey stores the given key and waits for commit
// SetPreSharedKey store the given key and wait for commit
func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
_, err := internal.UpdateOrCreateConfig(p.configInput)
return err
}

View File

@@ -4,7 +4,7 @@ import (
"path/filepath"
"testing"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal"
)
func TestPreferences_DefaultValues(t *testing.T) {
@@ -15,7 +15,7 @@ func TestPreferences_DefaultValues(t *testing.T) {
t.Fatalf("failed to read default value: %s", err)
}
if defaultVar != profilemanager.DefaultAdminURL {
if defaultVar != internal.DefaultAdminURL {
t.Errorf("invalid default admin url: %s", defaultVar)
}
@@ -24,7 +24,7 @@ func TestPreferences_DefaultValues(t *testing.T) {
t.Fatalf("failed to read default management URL: %s", err)
}
if defaultVar != profilemanager.DefaultManagementURL {
if defaultVar != internal.DefaultManagementURL {
t.Errorf("invalid default management url: %s", defaultVar)
}

View File

@@ -69,22 +69,6 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip]
}
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -13,23 +13,14 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const errCloseConnection = "Failed to close connection: %v"
var (
logFileCount uint32
systemInfoFlag bool
uploadBundleFlag bool
uploadBundleURLFlag string
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -97,13 +88,12 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -115,7 +105,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if uploadBundleFlag {
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -233,13 +223,12 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
@@ -266,7 +255,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if uploadBundleFlag {
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
@@ -308,7 +297,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
)
}
return statusOutputString
@@ -356,7 +345,7 @@ func formatDuration(d time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
@@ -386,15 +375,3 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}

View File

@@ -12,12 +12,11 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func SetupDebugHandler(
ctx context.Context,
config *profilemanager.Config,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
@@ -29,7 +28,7 @@ const (
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *profilemanager.Config,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
@@ -84,7 +83,7 @@ func SetupDebugHandler(
func waitForEvent(
ctx context.Context,
config *profilemanager.Config,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,

View File

@@ -20,7 +20,7 @@ var downCmd = &cobra.Command{
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, util.LogConsole)
err := util.InitLog(logLevel, "console")
if err != nil {
log.Errorf("failed initializing log %v", err)
return err

View File

@@ -4,12 +4,9 @@ import (
"context"
"fmt"
"os"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
@@ -17,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
@@ -25,16 +21,19 @@ import (
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
}
var loginCmd = &cobra.Command{
Use: "login",
Short: "login to the Netbird Management Service (first run)",
RunE: func(cmd *cobra.Command, args []string) error {
if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err)
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx := internal.CtxInitState(context.Background())
@@ -43,17 +42,6 @@ var loginCmd = &cobra.Command{
// nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
username, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
pm := profilemanager.NewProfileManager()
activeProf, err := getActiveProfile(cmd.Context(), pm, profileName, username.Username)
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
providedSetupKey, err := getSetupKey()
if err != nil {
@@ -61,15 +49,97 @@ var loginCmd = &cobra.Command{
}
// workaround to run without service
if util.FindFirstLogPath(logFiles) == "" {
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
if logFile == "console" {
err = handleRebrand(cmd)
if err != nil {
return err
}
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
cmd.Println("Logging successfully")
return nil
}
if err := doDaemonLogin(ctx, cmd, providedSetupKey, activeProf, username.Username, pm); err != nil {
return fmt.Errorf("daemon login failed: %v", err)
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
cmd.Println("Logging successfully")
@@ -78,196 +148,7 @@ var loginCmd = &cobra.Command{
},
}
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &activeProf.Name,
Username: &username,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil {
return fmt.Errorf("sso login failed: %v", err)
}
}
return nil
}
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
// switch profile if provided
if profileName != "" {
if err := switchProfileOnDaemon(ctx, pm, profileName, username); err != nil {
return nil, fmt.Errorf("switch profile: %v", err)
}
}
activeProf, err := pm.GetActiveProfile()
if err != nil {
return nil, fmt.Errorf("get active profile: %v", err)
}
if activeProf == nil {
return nil, fmt.Errorf("active profile not found, please run 'netbird profile create' first")
}
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
log.Errorf("failed to connect to service CLI interface %v", err)
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
log.Errorf("call service down method: %v", err)
return err
}
}
return nil
}
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &username,
})
if err != nil {
return fmt.Errorf("switch profile failed: %v", err)
}
return nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
err := handleRebrand(cmd)
if err != nil {
return err
}
// update host's static platform and system information
system.UpdateStaticInfo()
configFilePath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile file path: %v", err)
}
config, err := profilemanager.ReadConfig(configFilePath)
if err != nil {
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
cmd.Println("Logging successfully")
return nil
}
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
if resp.Email != "" {
err = pm.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email,
})
if err != nil {
log.Warnf("failed to set active profile email: %v", err)
}
}
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
needsLogin := false
err := WithBackOff(func() error {
@@ -313,8 +194,8 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
if err != nil {
return nil, err
}
@@ -362,23 +243,7 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
}
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func setEnvAndFlags(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
return nil
}

View File

@@ -2,11 +2,11 @@ package cmd
import (
"fmt"
"os/user"
"strings"
"testing"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util"
)
@@ -14,41 +14,40 @@ func TestLogin(t *testing.T) {
mgmAddr := startTestingServices(t)
tempDir := t.TempDir()
currUser, err := user.Current()
if err != nil {
t.Fatalf("failed to get current user: %v", err)
return
}
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
})
confPath := tempDir + "/config.json"
mgmtURL := fmt.Sprintf("http://%s", mgmAddr)
rootCmd.SetArgs([]string{
"login",
"--config",
confPath,
"--log-file",
util.LogConsole,
"console",
"--setup-key",
strings.ToUpper("a2c8e62b-38f5-4553-b31e-dd66c696cebb"),
"--management-url",
mgmtURL,
})
// TODO(hakan): fix this test
_ = rootCmd.Execute()
err := rootCmd.Execute()
if err != nil {
t.Fatal(err)
}
// validate generated config
actualConf := &internal.Config{}
_, err = util.ReadJson(confPath, actualConf)
if err != nil {
t.Errorf("expected proper config file written, got broken %v", err)
}
if actualConf.ManagementURL.String() != mgmtURL {
t.Errorf("expected management URL %s got %s", mgmtURL, actualConf.ManagementURL.String())
}
if actualConf.WgIface != iface.WgInterfaceDefault {
t.Errorf("expected WgIfaceName %s got %s", iface.WgInterfaceDefault, actualConf.WgIface)
}
if len(actualConf.PrivateKey) == 0 {
t.Errorf("expected non empty Private key, got empty")
}
}

View File

@@ -1,236 +0,0 @@
package cmd
import (
"context"
"fmt"
"time"
"os/user"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var profileCmd = &cobra.Command{
Use: "profile",
Short: "manage Netbird profiles",
Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`,
}
var profileListCmd = &cobra.Command{
Use: "list",
Short: "list all profiles",
Long: `List all available profiles in the Netbird client.`,
RunE: listProfilesFunc,
}
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "add a new profile",
Long: `Add a new profile to the Netbird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>",
Short: "remove a profile",
Long: `Remove a profile from the Netbird client. The profile must not be active.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>",
Short: "select a profile",
Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return err
}
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
}
return nil
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
}
cmd.Println("Profile added successfully:", profileName)
return nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
}
cmd.Println("Profile removed successfully:", profileName)
return nil
}
func selectProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
profileManager := profilemanager.NewProfileManager()
profileName := args[0]
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return fmt.Errorf("list profiles: %w", err)
}
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
return err
}
status, err := daemonClient.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("get service status: %w", err)
}
if status.Status == string(internal.StatusConnected) {
if _, err := daemonClient.Down(ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("call service down method: %w", err)
}
}
cmd.Println("Profile switched successfully to:", profileName)
return nil
}

View File

@@ -10,7 +10,6 @@ import (
"os/signal"
"path"
"runtime"
"slices"
"strings"
"syscall"
"time"
@@ -22,26 +21,31 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection"
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
var (
configPath string
defaultConfigPathDir string
defaultConfigPath string
oldDefaultConfigPathDir string
@@ -51,7 +55,7 @@ var (
defaultLogFile string
oldDefaultLogFileDir string
oldDefaultLogFile string
logFiles []string
logFile string
daemonAddr string
managementURL string
adminURL string
@@ -67,12 +71,15 @@ var (
interfaceName string
wireguardPort uint16
networkMonitor bool
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
lazyConnEnabled bool
profilesDisabled bool
blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
rootCmd = &cobra.Command{
Use: "netbird",
@@ -116,19 +123,26 @@ func init() {
defaultDaemonAddr = "tcp://127.0.0.1:41731"
}
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "(DEPRECATED) Netbird config file location")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
rootCmd.AddCommand(statusCmd)
@@ -138,7 +152,9 @@ func init() {
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -151,12 +167,6 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)
profileCmd.AddCommand(profileRemoveCmd)
profileCmd.AddCommand(profileSelectCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -174,8 +184,10 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success
@@ -183,13 +195,14 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
termCh := make(chan os.Signal, 1)
signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
go func() {
defer cancel()
done := ctx.Done()
select {
case <-ctx.Done():
case <-done:
case <-termCh:
}
log.Info("shutdown signal received")
cancel()
}()
}
@@ -273,7 +286,7 @@ func getSetupKeyFromFile(setupKeyPath string) (string, error) {
func handleRebrand(cmd *cobra.Command) error {
var err error
if slices.Contains(logFiles, defaultLogFile) {
if logFile == defaultLogFile {
if migrateToNetbird(oldDefaultLogFile, defaultLogFile) {
cmd.Printf("will copy Log dir %s and its content to %s\n", oldDefaultLogFileDir, defaultLogFileDir)
err = cpDir(oldDefaultLogFileDir, defaultLogFileDir)
@@ -282,14 +295,15 @@ func handleRebrand(cmd *cobra.Command) error {
}
}
}
if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir)
err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir)
if err != nil {
return err
if configPath == defaultConfigPath {
if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir)
err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -1,15 +1,11 @@
//go:build !ios && !android
package cmd
import (
"context"
"fmt"
"runtime"
"strings"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
@@ -17,16 +13,6 @@ import (
"github.com/netbirdio/netbird/client/server"
)
var serviceCmd = &cobra.Command{
Use: "service",
Short: "manages Netbird service",
}
var (
serviceName string
serviceEnvVars []string
)
type program struct {
ctx context.Context
cancel context.CancelFunc
@@ -35,81 +21,30 @@ type program struct {
serverInstanceMu sync.Mutex
}
func init() {
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
rootCmd.AddCommand(serviceCmd)
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
ctx = internal.CtxInitState(ctx)
return &program{ctx: ctx, cancel: cancel}
}
func newSVCConfig() (*service.Config, error) {
config := &service.Config{
func newSVCConfig() *service.Config {
return &service.Config{
Name: serviceName,
DisplayName: "Netbird",
Description: "Netbird mesh network client",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue),
EnvVars: make(map[string]string),
}
if len(serviceEnvVars) > 0 {
extraEnvs, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
return nil, fmt.Errorf("parse service environment variables: %w", err)
}
config.EnvVars = extraEnvs
}
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config, nil
}
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
return service.New(prg, conf)
}
func parseServiceEnvVars(envVars []string) (map[string]string, error) {
envMap := make(map[string]string)
for _, env := range envVars {
if env == "" {
continue
}
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("empty environment variable key in: %s", env)
}
envMap[key] = value
s, err := service.New(prg, conf)
if err != nil {
log.Fatal(err)
return nil, err
}
return envMap, nil
return s, nil
}
var serviceCmd = &cobra.Command{
Use: "service",
Short: "manages Netbird service",
}

View File

@@ -1,5 +1,3 @@
//go:build !ios && !android
package cmd
import (
@@ -49,19 +47,20 @@ func (p *program) Start(svc service.Service) error {
listen, err := net.Listen(split[0], split[1])
if err != nil {
return fmt.Errorf("listen daemon interface: %w", err)
return fmt.Errorf("failed to listen daemon interface: %w", err)
}
go func() {
defer listen.Close()
if split[0] == "unix" {
if err := os.Chmod(split[1], 0666); err != nil {
err = os.Chmod(split[1], 0666)
if err != nil {
log.Errorf("failed setting daemon permissions: %v", split[1])
return
}
}
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), profilesDisabled)
serverInstance := server.New(p.ctx, configPath, logFile)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}
@@ -101,49 +100,37 @@ func (p *program) Stop(srv service.Service) error {
return nil
}
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())
if err := handleRebrand(cmd); err != nil {
return nil, err
}
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
cfg, err := newSVCConfig()
if err != nil {
return nil, fmt.Errorf("create service config: %w", err)
}
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return nil, err
}
return s, nil
}
var runCmd = &cobra.Command{
Use: "run",
Short: "runs Netbird as service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
SetFlagsFromEnvVars(rootCmd)
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
cmd.SetOut(cmd.OutOrStdout())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
err := handleRebrand(cmd)
if err != nil {
return err
}
return s.Run()
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {
return err
}
err = s.Run()
if err != nil {
return err
}
return nil
},
}
@@ -151,14 +138,31 @@ var startCmd = &cobra.Command{
Use: "start",
Short: "starts Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
if err := s.Start(); err != nil {
return fmt.Errorf("start service: %w", err)
err = util.InitLog(logLevel, logFile)
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {
cmd.PrintErrln(err)
return err
}
err = s.Start()
if err != nil {
cmd.PrintErrln(err)
return err
}
cmd.Println("Netbird service has been started")
return nil
@@ -169,14 +173,29 @@ var stopCmd = &cobra.Command{
Use: "stop",
Short: "stops Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
if err := s.Stop(); err != nil {
return fmt.Errorf("stop service: %w", err)
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {
return err
}
err = s.Stop()
if err != nil {
return err
}
cmd.Println("Netbird service has been stopped")
return nil
@@ -187,48 +206,31 @@ var restartCmd = &cobra.Command{
Use: "restart",
Short: "restarts Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
if err := s.Restart(); err != nil {
return fmt.Errorf("restart service: %w", err)
err = util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {
return err
}
err = s.Restart()
if err != nil {
return err
}
cmd.Println("Netbird service has been restarted")
return nil
},
}
var svcStatusCmd = &cobra.Command{
Use: "status",
Short: "shows Netbird service status",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
if err != nil {
return err
}
status, err := s.Status()
if err != nil {
return fmt.Errorf("get service status: %w", err)
}
var statusText string
switch status {
case service.StatusRunning:
statusText = "Running"
case service.StatusStopped:
statusText = "Stopped"
case service.StatusUnknown:
statusText = "Unknown"
default:
statusText = fmt.Sprintf("Unknown (%d)", status)
}
cmd.Printf("Netbird service status: %s\n", statusText)
return nil
},
}

View File

@@ -1,121 +1,87 @@
//go:build !ios && !android
package cmd
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"github.com/kardianos/service"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/util"
)
var ErrGetServiceStatus = fmt.Errorf("failed to get service status")
// Common service command setup
func setupServiceCommand(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())
return handleRebrand(cmd)
}
// Build service arguments for install/reconfigure
func buildServiceArguments() []string {
args := []string{
"service",
"run",
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {
args = append(args, "--management-url", managementURL)
}
for _, logFile := range logFiles {
args = append(args, "--log-file", logFile)
}
return args
}
// Configure platform-specific service settings
func configurePlatformSpecificSettings(svcConfig *service.Config) error {
if runtime.GOOS == "linux" {
// Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile := util.FindFirstLogPath(logFiles); logFile != "" {
setStdLogPath := true
dir := filepath.Dir(logFile)
if _, err := os.Stat(dir); err != nil {
if err = os.MkdirAll(dir, 0750); err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
return nil
}
// Create fully configured service config for install/reconfigure
func createServiceConfigForInstall() (*service.Config, error) {
svcConfig, err := newSVCConfig()
if err != nil {
return nil, fmt.Errorf("create service config: %w", err)
}
svcConfig.Arguments = buildServiceArguments()
if err = configurePlatformSpecificSettings(svcConfig); err != nil {
return nil, fmt.Errorf("configure platform-specific settings: %w", err)
}
return svcConfig, nil
}
var installCmd = &cobra.Command{
Use: "install",
Short: "installs Netbird service",
RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return err
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
svcConfig := newSVCConfig()
svcConfig.Arguments = []string{
"service",
"run",
"--config",
configPath,
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
}
if logFile != "console" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
}
if runtime.GOOS == "linux" {
// Respected only by systemd systems
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
if logFile != "console" {
setStdLogPath := true
dir := filepath.Dir(logFile)
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0750)
if err != nil {
setStdLogPath = false
}
}
if setStdLogPath {
svcConfig.Option["LogOutput"] = true
svcConfig.Option["LogDirectory"] = dir
}
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil {
cmd.PrintErrln(err)
return err
}
if err := s.Install(); err != nil {
return fmt.Errorf("install service: %w", err)
err = s.Install()
if err != nil {
cmd.PrintErrln(err)
return err
}
cmd.Println("Netbird service has been installed")
@@ -127,109 +93,27 @@ var uninstallCmd = &cobra.Command{
Use: "uninstall",
Short: "uninstalls Netbird service from system",
RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil {
return err
}
SetFlagsFromEnvVars(rootCmd)
cfg, err := newSVCConfig()
cmd.SetOut(cmd.OutOrStdout())
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("create service config: %w", err)
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), cfg)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {
return err
}
if err := s.Uninstall(); err != nil {
return fmt.Errorf("uninstall service: %w", err)
err = s.Uninstall()
if err != nil {
return err
}
cmd.Println("Netbird service has been uninstalled")
return nil
},
}
var reconfigureCmd = &cobra.Command{
Use: "reconfigure",
Short: "reconfigures Netbird service with new settings",
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install.
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil {
return err
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
if err != nil {
return fmt.Errorf("create service: %w", err)
}
if wasRunning {
cmd.Println("Stopping Netbird service...")
if err := s.Stop(); err != nil {
cmd.Printf("Warning: failed to stop service: %v\n", err)
}
}
cmd.Println("Removing existing service configuration...")
if err := s.Uninstall(); err != nil {
return fmt.Errorf("uninstall existing service: %w", err)
}
cmd.Println("Installing service with new configuration...")
if err := s.Install(); err != nil {
return fmt.Errorf("install service with new config: %w", err)
}
if wasRunning {
cmd.Println("Starting Netbird service...")
if err := s.Start(); err != nil {
return fmt.Errorf("start service after reconfigure: %w", err)
}
cmd.Println("Netbird service has been reconfigured and started")
} else {
cmd.Println("Netbird service has been reconfigured")
}
return nil
},
}
func isServiceRunning() (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctx, cancel), cfg)
if err != nil {
return false, err
}
status, err := s.Status()
if err != nil {
return false, fmt.Errorf("%w: %w", ErrGetServiceStatus, err)
}
return status == service.StatusRunning, nil
}

View File

@@ -1,263 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) {
tests := []struct {
name string
envVars []string
expected map[string]string
expectErr bool
}{
{
name: "Valid single env var",
envVars: []string{"LOG_LEVEL=debug"},
expected: map[string]string{
"LOG_LEVEL": "debug",
},
},
{
name: "Valid multiple env vars",
envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"},
expected: map[string]string{
"LOG_LEVEL": "debug",
"CUSTOM_VAR": "value",
},
},
{
name: "Env var with spaces",
envVars: []string{" KEY = value "},
expected: map[string]string{
"KEY": "value",
},
},
{
name: "Invalid format - no equals",
envVars: []string{"INVALID"},
expectErr: true,
},
{
name: "Invalid format - empty key",
envVars: []string{"=value"},
expectErr: true,
},
{
name: "Empty value is valid",
envVars: []string{"KEY="},
expected: map[string]string{
"KEY": "",
},
},
{
name: "Empty slice",
envVars: []string{},
expected: map[string]string{},
},
{
name: "Empty string in slice",
envVars: []string{"", "KEY=value", ""},
expected: map[string]string{"KEY": "value"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseServiceEnvVars(tt.envVars)
if tt.expectErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
// TestServiceConfigWithEnvVars tests service config creation with env vars
func TestServiceConfigWithEnvVars(t *testing.T) {
originalServiceName := serviceName
originalServiceEnvVars := serviceEnvVars
defer func() {
serviceName = originalServiceName
serviceEnvVars = originalServiceEnvVars
}()
serviceName = "test-service"
serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"}
cfg, err := newSVCConfig()
require.NoError(t, err)
assert.Equal(t, "test-service", cfg.Name)
assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"])
assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"])
if runtime.GOOS == "linux" {
assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"])
}
}

View File

@@ -12,15 +12,14 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
port int
userName = "root"
host string
port int
user = "root"
host string
)
var sshCmd = &cobra.Command{
@@ -32,7 +31,7 @@ var sshCmd = &cobra.Command{
split := strings.Split(args[0], "@")
if len(split) == 2 {
userName = split[0]
user = split[0]
host = split[1]
} else {
host = args[0]
@@ -47,7 +46,7 @@ var sshCmd = &cobra.Command{
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, util.LogConsole)
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
@@ -59,19 +58,11 @@ var sshCmd = &cobra.Command{
ctx := internal.CtxInitState(cmd.Context())
pm := profilemanager.NewProfileManager()
activeProf, err := pm.GetActiveProfile()
config, err := internal.UpdateConfig(internal.ConfigInput{
ConfigPath: configPath,
})
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
profPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile path: %v", err)
}
config, err := profilemanager.ReadConfig(profPath)
if err != nil {
return fmt.Errorf("read profile config: %v", err)
return err
}
sig := make(chan os.Signal, 1)
@@ -98,7 +89,7 @@ var sshCmd = &cobra.Command{
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +

View File

@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
@@ -27,7 +26,6 @@ var (
statusFilter string
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
)
var statusCmd = &cobra.Command{
@@ -46,8 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -60,7 +57,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err
}
err = util.InitLog(logLevel, util.LogConsole)
err = util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
@@ -72,10 +69,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err
}
status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -92,13 +86,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
var statusOutputString string
switch {
case detailFlag:
@@ -129,7 +117,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -139,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected":
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
}
if len(ipsFilter) > 0 {
@@ -165,15 +153,6 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag()
}
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil
}

View File

@@ -6,8 +6,6 @@ const (
disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
)
var (
@@ -15,8 +13,6 @@ var (
disableServerRoutes bool
disableDNS bool
disableFirewall bool
blockLANAccess bool
blockInbound bool
)
func init() {
@@ -32,11 +28,4 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
}

View File

@@ -103,13 +103,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -124,7 +124,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
}
func startClientDaemon(
t *testing.T, ctx context.Context, _, _ string,
t *testing.T, ctx context.Context, _, configPath string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")
@@ -134,7 +134,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
"", false)
configPath, "")
if err := server.Start(); err != nil {
t.Fatal(err)
}

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"os/user"
"runtime"
"strings"
"time"
@@ -13,14 +12,12 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
@@ -38,9 +35,6 @@ const (
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
)
var (
@@ -48,8 +42,6 @@ var (
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
profileName string
configPath string
upCmd = &cobra.Command{
Use: "up",
@@ -63,11 +55,12 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
@@ -78,8 +71,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
}
@@ -89,7 +80,7 @@ func upFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, util.LogConsole)
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
@@ -111,41 +102,13 @@ func upFunc(cmd *cobra.Command, args []string) error {
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
pm := profilemanager.NewProfileManager()
username, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
var profileSwitched bool
// switch profile if provided
if profileName != "" {
err = switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true
}
activeProf, err := pm.GetActiveProfile()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
if foregroundMode {
return runInForegroundMode(ctx, cmd, activeProf)
return runInForegroundMode(ctx, cmd)
}
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
return runInDaemonMode(ctx, cmd)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
err := handleRebrand(cmd)
if err != nil {
return err
@@ -156,245 +119,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
return err
}
configFilePath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile file path: %v", err)
}
ic, err := setupConfig(customDNSAddressConverted, cmd, configFilePath)
if err != nil {
return fmt.Errorf("setup config: %v", err)
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := profilemanager.UpdateOrCreateConfig(*ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return fmt.Errorf("parse custom DNS address: %v", err)
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
if !profileSwitched {
cmd.Println("Already connected")
return nil
}
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
log.Errorf("call service down method: %v", err)
return err
}
}
username, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon")
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
}
if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil {
return fmt.Errorf("daemon up failed: %v", err)
}
cmd.Println("Connected")
return nil
}
func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, customDNSAddressConverted []byte, username string) error {
providedSetupKey, err := getSetupKey()
if err != nil {
return fmt.Errorf("get setup key: %v", err)
}
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil {
return fmt.Errorf("sso login failed: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &activeProf.Name,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
return nil
}
func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, profileName, username string) *proto.SetConfigRequest {
var req proto.SetConfigRequest
req.ProfileName = profileName
req.Username = username
req.ManagementUrl = managementURL
req.AdminURL = adminURL
req.NatExternalIPs = natExternalIPs
req.CustomDNSAddress = customDNSAddressConverted
req.ExtraIFaceBlacklist = extraIFaceBlackList
req.DnsLabels = dnsLabelsValidated.ToPunycodeList()
req.CleanDNSLabels = dnsLabels != nil && len(dnsLabels) == 0
req.CleanNATExternalIPs = natExternalIPs != nil && len(natExternalIPs) == 0
if cmd.Flag(enableRosenpassFlag).Changed {
req.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
req.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
return nil
}
req.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int64(wireguardPort)
req.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
req.OptionalPreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
req.DisableAutoConnect = &autoConnectDisabled
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
req.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
if cmd.Flag(disableClientRoutesFlag).Changed {
req.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
req.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
req.DisableDns = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
req.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
req.BlockLanAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
req.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled
}
return &req
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFilePath string) (*profilemanager.ConfigInput, error) {
ic := profilemanager.ConfigInput{
ic := internal.ConfigInput{
ManagementURL: managementURL,
ConfigPath: configFilePath,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
@@ -415,7 +143,7 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
return err
}
ic.InterfaceName = &interfaceName
}
@@ -466,28 +194,85 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &blockInbound
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
return &ic, nil
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
cmd.Println("Already connected")
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -512,7 +297,7 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
return err
}
loginRequest.InterfaceName = &interfaceName
}
@@ -547,14 +332,45 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockLanAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
loginRequest.BlockInbound = &blockInbound
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
return &loginRequest, nil
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func validateNATExternalIPs(list []string) error {
@@ -638,7 +454,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
if !isValidAddrPort(customDNSAddress) {
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
}
if customDNSAddress == "" && util.FindFirstLogPath(logFiles) != "" {
if customDNSAddress == "" && logFile != "console" {
parsed = []byte("empty")
} else {
parsed = []byte(customDNSAddress)

View File

@@ -3,55 +3,18 @@ package cmd
import (
"context"
"os"
"os/user"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
var cliAddr string
func TestUpDaemon(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
profilemanager.ConfigDirOverride = tempDir
currUser, err := user.Current()
if err != nil {
t.Fatalf("failed to get current user: %v", err)
return
}
sm := profilemanager.ServiceManager{}
err = sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test1",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
return
}
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
profilemanager.ConfigDirOverride = ""
})
mgmAddr := startTestingServices(t)
tempDir := t.TempDir()
confPath := tempDir + "/config.json"
ctx := internal.CtxInitState(context.Background())

View File

@@ -17,7 +17,6 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -27,7 +26,7 @@ var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *profilemanager.Config
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
@@ -89,9 +88,9 @@ func New(opts Options) (*Client, error) {
}
t := true
var config *profilemanager.Config
var config *internal.Config
var err error
input := profilemanager.ConfigInput{
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
@@ -99,9 +98,9 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = profilemanager.CreateInMemoryConfig(input)
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)

View File

@@ -147,10 +147,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -202,7 +198,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
firewall.ProtocolALL,
"all",
nil,
nil,
firewall.ActionAccept,
@@ -223,16 +219,10 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
}

View File

@@ -2,7 +2,7 @@ package iptables
import (
"fmt"
"net/netip"
"net"
"testing"
"time"
@@ -19,8 +19,11 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -67,12 +70,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -92,9 +95,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := netip.MustParseAddr("10.20.0.3")
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
@@ -116,8 +119,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -138,11 +144,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -180,8 +186,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -203,11 +212,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err)
ip := netip.MustParseAddr("10.20.0.100")
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -248,6 +248,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -274,6 +278,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)

View File

@@ -116,8 +116,6 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,

View File

@@ -170,10 +170,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -328,16 +324,10 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
}

View File

@@ -3,6 +3,7 @@ package nftables
import (
"bytes"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
@@ -24,8 +25,11 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -66,11 +70,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
ip := netip.MustParseAddr("100.96.0.1").Unmap()
ip := net.ParseIP("100.96.0.1")
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -105,6 +109,8 @@ func TestNftablesManager(t *testing.T) {
}
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
@@ -126,7 +132,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip.AsSlice(),
Data: add.AsSlice(),
},
&expr.Payload{
DestRegister: 1,
@@ -167,8 +173,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
@@ -188,11 +197,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
}()
ip := netip.MustParseAddr("10.20.0.100")
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -273,8 +282,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(

View File

@@ -573,6 +573,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -1002,6 +1006,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}

View File

@@ -3,7 +3,6 @@ package conntrack
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
@@ -20,10 +19,6 @@ const (
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
)
// ICMPConnKey uniquely identifies an ICMP connection
@@ -34,7 +29,7 @@ type ICMPConnKey struct {
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID)
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
}
// ICMPConnTrack represents an ICMP connection state
@@ -55,72 +50,6 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger
}
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 {
@@ -164,64 +93,30 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -243,7 +138,7 @@ func (t *ICMPTracker) track(
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
}
})
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
}
b.ResetTimer()

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
ip net.IP
netstack bool
}
@@ -71,11 +71,12 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: iface.Address().Network.Bits(),
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
@@ -115,7 +116,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
@@ -166,7 +167,7 @@ func (f *Forwarder) Stop() {
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
@@ -178,6 +179,7 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut)
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn)
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
}
}

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr)
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr)
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
}
var rxPackets, txPackets uint64

View File

@@ -45,26 +45,24 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil {
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
}
}
}
@@ -81,12 +79,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,13 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue
}
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
@@ -124,8 +116,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}()
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []netip.Addr
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@@ -20,8 +20,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost range",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
@@ -29,8 +32,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
@@ -38,8 +44,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
@@ -47,8 +56,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
@@ -56,8 +68,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
@@ -65,8 +80,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
@@ -74,8 +92,11 @@ func TestLocalIPManager(t *testing.T) {
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
expected: false,

View File

@@ -1,408 +0,0 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@@ -1,416 +0,0 @@
package uspfilter
import (
"fmt"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
)
// BenchmarkDNATTranslation measures the performance of DNAT operations
func BenchmarkDNATTranslation(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
description string
}{
{
name: "tcp_with_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: true,
description: "TCP packet with DNAT translation enabled",
},
{
name: "tcp_without_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
description: "TCP packet without DNAT (baseline)",
},
{
name: "udp_with_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: true,
description: "UDP packet with DNAT translation enabled",
},
{
name: "udp_without_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
description: "UDP packet without DNAT (baseline)",
},
{
name: "icmp_with_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: true,
description: "ICMP packet with DNAT translation enabled",
},
{
name: "icmp_without_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: false,
description: "ICMP packet without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mapping if needed
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
if sc.setupDNAT {
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Create test packets
srcIP := netip.MustParseAddr("172.16.0.1")
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
// Pre-establish connection for reverse DNAT test
if sc.setupDNAT {
manager.filterOutbound(outboundPacket, 0)
}
b.ResetTimer()
// Benchmark outbound DNAT translation
b.Run("outbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
// Benchmark inbound reverse DNAT translation
if sc.setupDNAT {
b.Run("inbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
manager.filterInbound(packet, 0)
}
})
}
})
}
}
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup multiple DNAT mappings
numMappings := 100
originalIPs := make([]netip.Addr, numMappings)
translatedIPs := make([]netip.Addr, numMappings)
for i := 0; i < numMappings; i++ {
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
require.NoError(b, err)
}
srcIP := netip.MustParseAddr("172.16.0.1")
// Pre-generate packets
outboundPackets := make([][]byte, numMappings)
inboundPackets := make([][]byte, numMappings)
for i := 0; i < numMappings; i++ {
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
// Establish connections
manager.filterOutbound(outboundPackets[i], 0)
}
b.ResetTimer()
b.Run("concurrent_outbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
i++
}
})
})
b.Run("concurrent_inbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
manager.filterInbound(packet, 0)
i++
}
})
})
}
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
func BenchmarkDNATScaling(b *testing.B) {
mappingCounts := []int{1, 10, 100, 1000}
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mappings
for i := 0; i < count; i++ {
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Test with the last mapping added (worst case for lookup)
srcIP := netip.MustParseAddr("172.16.0.1")
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
}
}
// generateDNATTestPacket creates a test packet for DNAT benchmarking
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
tb.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: proto,
}
var transportLayer gopacket.SerializableLayer
switch proto {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
case layers.IPProtocolICMPv4:
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
transportLayer = icmp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(tb, err)
return buf.Bytes()
}
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
func BenchmarkChecksumUpdate(b *testing.B) {
// Create test data for checksum calculations
testData := make([]byte, 64) // Typical packet size for checksum testing
for i := range testData {
testData[i] = byte(i)
}
b.Run("ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
}
})
b.Run("icmp_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = icmpChecksum(testData)
}
})
b.Run("incremental_update", func(b *testing.B) {
oldBytes := []byte{192, 168, 1, 100}
newBytes := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
}
})
}
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Create fresh packet each time to isolate allocation testing
testPacket := make([]byte, len(packet))
copy(testPacket, packet)
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
}
}
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
func BenchmarkDirectIPExtraction(b *testing.B) {
// Create a test packet
srcIP := netip.MustParseAddr("172.16.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
b.Run("direct_byte_access", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Direct extraction from packet bytes
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
}
})
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
// Extract using decoder (traditional method)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
_ = dst
}
})
}
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
func BenchmarkChecksumOptimizations(b *testing.B) {
// Create test IPv4 header (20 bytes)
header := make([]byte, 20)
for i := range header {
header[i] = byte(i)
}
// Clear checksum field
header[10] = 0
header[11] = 0
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(header)
}
})
// Test incremental checksum updates
oldIP := []byte{192, 168, 1, 100}
newIP := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
b.Run("optimized_incremental_update", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
}
})
}

View File

@@ -1,145 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
// Add DNAT mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcPort uint16
dstPort uint16
}{
{"TCP", layers.IPProtocolTCP, 12345, 80},
{"UDP", layers.IPProtocolUDP, 12345, 53},
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test outbound DNAT translation
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound packet (should translate destination)
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, translated, "Outbound packet should be translated")
// Verify destination IP was changed
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
// Test inbound reverse DNAT translation
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet (should reverse translate source)
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, reversed, "Inbound packet should be reverse translated")
// Verify source IP was changed back to original
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
// Test that checksums are recalculated correctly
if tc.protocol != layers.IPProtocolICMPv4 {
// For TCP/UDP, verify the transport checksum was updated
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
}
})
}
}
// parsePacket helper to create a decoder for testing
func parsePacket(t testing.TB, packetData []byte) *decoder {
t.Helper()
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
// Test adding mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
// Verify mapping exists
result, exists := manager.getDNATTranslation(originalIP)
require.True(t, exists)
require.Equal(t, translatedIP, result)
// Test reverse lookup
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
require.True(t, exists)
require.Equal(t, originalIP, reverseResult)
// Test removing mapping
err = manager.RemoveInternalDNATMapping(originalIP)
require.NoError(t, err)
// Verify mapping no longer exists
_, exists = manager.getDNATTranslation(originalIP)
require.False(t, exists)
_, exists = manager.findReverseDNATMapping(translatedIP)
require.False(t, exists)
// Test error cases
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
require.Error(t, err, "Should reject invalid original IP")
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
require.Error(t, err, "Should reject invalid translated IP")
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}

View File

@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.filterOutbound(packetData, 0)
dropped := m.processOutgoingHooks(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -38,8 +38,11 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}

View File

@@ -39,12 +39,8 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
)
@@ -75,6 +71,7 @@ type Manager struct {
// incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -104,12 +101,6 @@ type Manager struct {
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
}
// decoder for packages
@@ -157,11 +148,6 @@ func parseCreateEnv() (bool, bool) {
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
}
return disableConntrack, enableLocalForwarding
@@ -195,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
}
m.routingEnabled.Store(false)
@@ -284,7 +269,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil:
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
@@ -341,10 +326,6 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
@@ -526,6 +507,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -572,14 +569,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// FilterOutBound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size)
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
}
// FilterInbound filters incoming packets
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
return m.filterInbound(packetData, size)
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
}
// UpdateLocalIPs updates the list of local IPs
@@ -587,7 +584,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -609,8 +606,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
if m.stateful {
m.trackOutbound(d, srcIP, dstIP, size)
}
return false
}
@@ -662,7 +660,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
}
}
@@ -675,7 +673,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
}
}
@@ -714,9 +712,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false
}
// filterInbound implements filtering logic for incoming packets.
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -738,15 +736,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false
}
@@ -786,10 +777,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true
}
// If requested we pass local traffic to internal interfaces to the forwarder.
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData)
// if running in netstack mode we need to pass this to the forwarder
if m.netstack && m.localForwarding {
return m.handleNetstackLocalTraffic(packetData)
}
// track inbound packets to get the correct direction and session id for flows
@@ -799,7 +789,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false
}
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load()
if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -1097,6 +1088,11 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true
}
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not

View File

@@ -174,6 +174,11 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup
sc.setupFunc(manager)
@@ -188,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -214,13 +219,18 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table
srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count)
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
// Test packet
@@ -228,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.filterOutbound(testOut, 0)
manager.processOutgoingHooks(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(testIn, 0)
manager.dropFilter(testIn, 0)
}
})
}
@@ -257,18 +267,23 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -289,6 +304,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -302,6 +321,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -316,6 +339,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -329,6 +356,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -342,6 +373,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -355,6 +390,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -369,6 +408,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "post_handshake",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -383,6 +426,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -396,6 +443,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -426,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.filterOutbound(outbound, 0)
manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(inbound, 0)
manager.dropFilter(inbound, 0)
}
})
}
@@ -542,6 +593,11 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
@@ -568,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@@ -599,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.filterOutbound(outPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.filterInbound(inPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
}
@@ -625,6 +681,11 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
@@ -700,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
// Data transfer
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
// Connection teardown
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
}
@@ -736,6 +797,11 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -760,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.filterOutbound(syn, 0)
manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.filterInbound(synack, 0)
manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.filterOutbound(ack, 0)
manager.processOutgoingHooks(ack, 0)
}
// Pre-generate test packets
@@ -790,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.filterOutbound(outPackets[connIdx], 0)
manager.filterInbound(inPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
})
@@ -816,6 +882,11 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
@@ -879,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.filterOutbound(p.syn, 0)
manager.filterInbound(p.synAck, 0)
manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.filterOutbound(p.request, 0)
manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.filterOutbound(p.finClient, 0)
manager.filterInbound(p.ackServer, 0)
manager.filterInbound(p.finServer, 0)
manager.filterOutbound(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
})
@@ -961,8 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
}
for _, r := range rules {
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}

View File

@@ -19,8 +19,12 @@ import (
)
func TestPeerACLFiltering(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
localIP := net.ParseIP("100.10.0.100")
wgNet := &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
@@ -39,6 +43,8 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil))
})
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs()
require.NoError(t, err)
@@ -462,7 +468,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@@ -509,7 +515,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -575,13 +581,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix(network)
localIP, wgNet, err := net.ParseCIDR(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
IP: localIP,
Network: wgNet,
}
},
@@ -1233,7 +1240,7 @@ func TestRouteACLFiltering(t *testing.T) {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)
@@ -1433,8 +1440,11 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}

View File

@@ -271,8 +271,11 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
@@ -282,6 +285,10 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err)
return
}
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
@@ -321,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.filterInbound(buf.Bytes(), 0) {
if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
@@ -389,6 +396,10 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() {
@@ -447,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.filterOutbound(buf.Bytes(), 0)
result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@@ -457,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.filterOutbound(buf.Bytes(), 0)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result)
}
@@ -498,6 +509,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{
@@ -553,7 +569,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
@@ -620,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -669,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -691,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.filterInbound(testBuf.Bytes(), 0)
drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@@ -1,96 +0,0 @@
package bind
import (
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/monotime"
)
const (
saveFrequency = int64(5 * time.Second)
)
type PeerRecord struct {
Address netip.AddrPort
LastActivity atomic.Int64 // UnixNano timestamp
}
type ActivityRecorder struct {
mu sync.RWMutex
peers map[string]*PeerRecord // publicKey to PeerRecord map
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
}
func NewActivityRecorder() *ActivityRecorder {
return &ActivityRecorder{
peers: make(map[string]*PeerRecord),
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
}
}
// GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
r.mu.RLock()
defer r.mu.RUnlock()
activities := make(map[string]monotime.Time, len(r.peers))
for key, record := range r.peers {
monoTime := record.LastActivity.Load()
activities[key] = monotime.Time(monoTime)
}
return activities
}
// UpsertAddress adds or updates the address for a publicKey
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
r.mu.Lock()
defer r.mu.Unlock()
var record *PeerRecord
record, exists := r.peers[publicKey]
if exists {
delete(r.addrToPeer, record.Address)
record.Address = address
} else {
record = &PeerRecord{
Address: address,
}
record.LastActivity.Store(int64(monotime.Now()))
r.peers[publicKey] = record
}
r.addrToPeer[address] = record
}
func (r *ActivityRecorder) Remove(publicKey string) {
r.mu.Lock()
defer r.mu.Unlock()
if record, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, record.Address)
delete(r.peers, publicKey)
}
}
// record updates LastActivity for the given address using atomic store
func (r *ActivityRecorder) record(address netip.AddrPort) {
r.mu.RLock()
record, ok := r.addrToPeer[address]
r.mu.RUnlock()
if !ok {
log.Warnf("could not find record for address %s", address)
return
}
now := int64(monotime.Now())
last := record.LastActivity.Load()
if now-last < saveFrequency {
return
}
_ = record.LastActivity.CompareAndSwap(last, now)
}

View File

@@ -1,25 +0,0 @@
package bind
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/monotime"
)
func TestActivityRecorder_GetLastActivities(t *testing.T) {
peer := "peer1"
ar := NewActivityRecorder()
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
activities := ar.GetLastActivities()
p, ok := activities[peer]
if !ok {
t.Fatalf("Expected activity for peer %s, but got none", peer)
}
if monotime.Since(p) > 5*time.Second {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
}
}

View File

@@ -1,15 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -0,0 +1,12 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -1,7 +1,6 @@
package bind
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -16,7 +15,6 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type RecvMessage struct {
@@ -53,24 +51,22 @@ type ICEBind struct {
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
activityRecorder: NewActivityRecorder(),
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -104,10 +100,6 @@ func (s *ICEBind) Close() error {
return s.StdNetBind.Close()
}
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
@@ -154,7 +146,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
@@ -207,11 +199,6 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
}
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
@@ -270,13 +257,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
if isTransportPkg(buffs, sizes[0]) {
if ep, ok := eps[0].(*Endpoint); ok {
c.activityRecorder.record(ep.AddrPort)
}
}
return 1, nil
}
}
@@ -292,19 +272,3 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
}
msgsPool.Put(msgs)
}
func isTransportPkg(buffers [][]byte, n int) bool {
// The first buffer should contain at least 4 bytes for type
if len(buffers[0]) < 4 {
return true
}
// WireGuard packet type is a little-endian uint32 at start
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
// Check if packetType matches known WireGuard message types
if packetType == 4 && n > 32 {
return true
}
return false
}

View File

@@ -296,20 +296,14 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return
}
var allAddresses []string
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
allAddresses = append(allAddresses, addresses...)
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
for _, addr := range addresses {
delete(m.addressMap, addr)
}
}
}
@@ -357,13 +351,14 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if !ok {
existing = []*udpMuxedConn{}
}
existing = append(existing, conn)
m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
@@ -391,12 +386,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.RLock()
m.addressMapMu.Lock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
m.addressMapMu.Unlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -1,22 +0,0 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
return
}
// Userspace mode: UDPConn wrapper around nbnet.PacketConn
if wrapped, ok := m.params.UDPConn.(*UDPConn); ok {
if conn, ok := wrapped.GetPacketConn().(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)
}
}
}

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &UDPConn{
m.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
@@ -70,6 +70,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
// embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
@@ -113,8 +114,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
}
}
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type UDPConn struct {
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
@@ -124,12 +125,7 @@ type UDPConn struct {
address wgaddr.Address
}
// GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
@@ -141,21 +137,21 @@ func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr)
}
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
func (u *udpConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
@@ -168,7 +164,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a) {
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}

View File

@@ -1,17 +0,0 @@
package configurer
import (
"net"
"net/netip"
)
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -5,18 +5,13 @@ package configurer
import (
"fmt"
"net"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -48,7 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -57,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps),
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,
@@ -94,10 +89,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
return nil
}
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipNet := net.IPNet{
IP: allowedIP.Addr().AsSlice(),
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -108,7 +103,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet},
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
@@ -121,10 +116,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix)
return nil
}
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipNet := net.IPNet{
IP: allowedIP.Addr().AsSlice(),
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -192,11 +187,7 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
if err != nil {
return err
}
defer func() {
if err := wg.Close(); err != nil {
log.Errorf("Failed to close wgctrl client: %v", err)
}
}()
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
@@ -210,75 +201,14 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() {
}
func (c *KernelConfigurer) FullStats() (*Stats, error) {
wg, err := wgctrl.New()
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
return WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}, nil
}

View File

@@ -1,11 +1,9 @@
package configurer
import (
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strconv"
@@ -16,40 +14,22 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
device *device.Device
deviceName string
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
device: device,
deviceName: deviceName,
}
wgCfg.startUAPI()
return wgCfg
@@ -72,7 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -81,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: prefixesToIPNets(allowedIps),
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
@@ -91,19 +71,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return ipcErr
}
if endpoint != nil {
addr, err := netip.ParseAddr(endpoint.IP.String())
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
@@ -120,16 +88,13 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
c.activityRecorder.Remove(peerKey)
return ipcErr
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipNet := net.IPNet{
IP: allowedIP.Addr().AsSlice(),
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -140,7 +105,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{ipNet},
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
@@ -150,7 +115,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) e
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
@@ -173,8 +138,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
foundPeer := false
removedAllowedIP := false
ip := allowedIP.String()
for _, line := range lines {
line = strings.TrimSpace(line)
@@ -197,8 +160,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
// Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIPStr)
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
@@ -215,19 +178,6 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
return c.activityRecorder.GetLastActivities()
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error
@@ -267,75 +217,91 @@ func (t *WGUSPConfigurer) Close() {
}
}
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("ipc get: %w", err)
return WGStats{}, fmt.Errorf("ipc get: %w", err)
}
return parseTransfers(ipc)
stats, err := findPeerInfo(ipc, peerKey, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
}
func parseTransfers(ipc string) (map[string]WGStats, error) {
stats := make(map[string]WGStats)
var (
currentKey string
currentStats WGStats
hasPeer bool
)
lines := strings.Split(ipc, "\n")
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return nil, fmt.Errorf("parse key: %w", err)
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") {
peerID := strings.TrimPrefix(line, "public_key=")
h, err := hex.DecodeString(peerID)
if err != nil {
return nil, fmt.Errorf("decode peerID: %w", err)
}
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
if strings.HasPrefix(line, "public_key=") && foundPeer {
break
}
if !hasPeer {
continue
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
key := strings.SplitN(line, "=", 2)
if len(key) != 2 {
continue
}
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
for _, key := range searchConfigKeys {
if foundPeer && strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
configFound[v[0]] = v[1]
}
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
}
}
return stats, nil
// todo: use multierr
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
}
func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -389,154 +355,9 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String()
}
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark
}
return 0
}
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -2,8 +2,10 @@ package configurer
import (
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -32,35 +34,58 @@ errno=0
`
func Test_parseTransfers(t *testing.T) {
func Test_findPeerInfo(t *testing.T) {
tests := []struct {
name string
peerKey string
want WGStats
name string
peerKey string
searchKeys []string
want map[string]string
wantErr bool
}{
{
name: "single",
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
want: WGStats{
TxBytes: 0,
RxBytes: 0,
name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
},
wantErr: false,
},
{
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
want: WGStats{
TxBytes: 38333,
RxBytes: 2224,
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
"rx_bytes": "2224",
},
wantErr: false,
},
{
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
want: WGStats{
TxBytes: 1212111,
RxBytes: 1929999999,
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "1212111",
"rx_bytes": "1929999999",
},
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
},
}
for _, tt := range tests {
@@ -71,19 +96,9 @@ func Test_parseTransfers(t *testing.T) {
key, err := wgtypes.NewKey(res)
require.NoError(t, err)
stats, err := parseTransfers(ipcFixture)
if err != nil {
require.NoError(t, err)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
})
}
}

View File

@@ -1,24 +0,0 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -24,7 +24,6 @@ type WGTunDevice struct {
mtu int
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
name string
device *device.Device
@@ -33,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -41,7 +40,6 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
}
}
@@ -51,13 +49,6 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)
@@ -79,7 +70,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -1,6 +1,7 @@
package device
import (
"net"
"net/netip"
"sync"
@@ -9,11 +10,11 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// FilterOutbound filter outgoing packets from host to external destinations
FilterOutbound(packetData []byte, size int) bool
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
@@ -23,6 +24,9 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
}
// FilteredDevice to override Read or Write of packets
@@ -54,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -78,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.FilterInbound(buf[offset:], len(buf)) {
if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -100,14 +99,8 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if err != nil {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,

View File

@@ -51,11 +51,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)
@@ -72,7 +68,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()

View File

@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()

View File

@@ -2,23 +2,19 @@ package device
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/monotime"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -64,15 +64,7 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
}
ip := address.IP.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
mask := "0x" + address.Network.Mask.String()
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
)
const (
@@ -30,11 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
var (
// ErrIfaceNotFound is returned when the WireGuard interface is not found
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -49,7 +43,6 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
DisableDNS bool
}
// WGIface represents an interface instance
@@ -118,50 +111,38 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
}
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional.
// If allowedIps is given it will be added to the existing ones.
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.RemovePeer(peerKey)
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP)
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
@@ -204,6 +185,7 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter)
return nil
@@ -230,32 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
// GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.GetStats()
}
func (w *WGIface) LastActivities() map[string]monotime.Time {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return nil
}
return w.configurer.LastActivities()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.FullStats()
// GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey)
}
func (w *WGIface) waitUntilRemoved() error {
@@ -292,3 +251,14 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet()
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil

View File

@@ -5,6 +5,7 @@
package mocks
import (
net "net"
"net/netip"
reflect "reflect"
@@ -48,32 +49,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
}
// RemovePacketHook mocks base method.
@@ -89,3 +90,15 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// SetNetwork mocks base method.

View File

@@ -1,6 +1,8 @@
package netstack
import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
@@ -13,8 +15,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive
address netip.Addr
dnsAddress netip.Addr
address net.IP
dnsAddress net.IP
mtu int
listenAddress string
@@ -22,7 +24,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device
}
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
return &NetStackTun{
address: address,
dnsAddress: dnsAddress,
@@ -32,21 +34,28 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
}
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{t.address},
[]netip.Addr{t.dnsAddress},
[]netip.Addr{addr.Unmap()},
[]netip.Addr{dnsAddr.Unmap()},
t.mtu)
if err != nil {
return nil, nil, err
}
t.tundev = nsTunDev
var skipProxy bool
if val := os.Getenv(EnvSkipProxy); val != "" {
skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
if skipProxy {
return nsTunDev, tunNet, nil

Some files were not shown because too many files have changed in this diff Show More