mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 17:04:21 -04:00
Compare commits
233 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c8473aed3 | ||
|
|
e1c66a8124 | ||
|
|
d89e6151a4 | ||
|
|
3d9be5098b | ||
|
|
cb8b6ca59b | ||
|
|
e0d9306b05 | ||
|
|
2c4ac33b38 | ||
|
|
31872a7fb6 | ||
|
|
cb85d3f2fc | ||
|
|
af8687579b | ||
|
|
3f82698089 | ||
|
|
cb1e437785 | ||
|
|
c435c2727f | ||
|
|
643730f770 | ||
|
|
04fae00a6c | ||
|
|
1a9ea32c21 | ||
|
|
0ea5d020a3 | ||
|
|
459c9ef317 | ||
|
|
e5e275c87a | ||
|
|
d311f57559 | ||
|
|
1a28d18cde | ||
|
|
91e7423989 | ||
|
|
86c16cf651 | ||
|
|
a7af15c4fc | ||
|
|
d6ed9c037e | ||
|
|
40fdeda838 | ||
|
|
f6e9d755e4 | ||
|
|
08fd460867 | ||
|
|
4f74509d55 | ||
|
|
58185ced16 | ||
|
|
e67f44f47c | ||
|
|
b524f486e2 | ||
|
|
0dab03252c | ||
|
|
e49bcc343d | ||
|
|
3e6eede152 | ||
|
|
a76c8eafb4 | ||
|
|
2b9f331980 | ||
|
|
a7ea881900 | ||
|
|
8632dd15f1 | ||
|
|
e3b40ba694 | ||
|
|
e59d75d56a | ||
|
|
408f423adc | ||
|
|
f17dd3619c | ||
|
|
969f1ed59a | ||
|
|
768ba24fda | ||
|
|
8942c40fde | ||
|
|
fbb1b55beb | ||
|
|
77ec32dd6f | ||
|
|
8c09a55057 | ||
|
|
f603ddf35e | ||
|
|
996b8c600c | ||
|
|
c4ed11d447 | ||
|
|
9afbecb7ac | ||
|
|
2c81cf2c1e | ||
|
|
551cb4e467 | ||
|
|
57961afe95 | ||
|
|
22678bce7f | ||
|
|
6c633497bc | ||
|
|
6922826919 | ||
|
|
56a1a75e3f | ||
|
|
d9402168ad | ||
|
|
dbdef04b9e | ||
|
|
29cbfe8467 | ||
|
|
6ce8643368 | ||
|
|
07d1ad35fc | ||
|
|
ef6cd36f1a | ||
|
|
c1c71b6d39 | ||
|
|
0480507a10 | ||
|
|
34ac4e4b5a | ||
|
|
52ff9d9602 | ||
|
|
1b73fae46e | ||
|
|
d897365abc | ||
|
|
f37aa2cc9d | ||
|
|
5343bee7b2 | ||
|
|
870e29db63 | ||
|
|
08e9b05d51 | ||
|
|
3581648071 | ||
|
|
2a51609436 | ||
|
|
83457f8b99 | ||
|
|
b45284f086 | ||
|
|
e9016aecea | ||
|
|
23b5d45b68 | ||
|
|
0e5dc9d412 | ||
|
|
91f7ee6a3c | ||
|
|
7c6b85b4cb | ||
|
|
08c9107c61 | ||
|
|
81d83245e1 | ||
|
|
af2b427751 | ||
|
|
f61ebdb3bc | ||
|
|
de7384e8ea | ||
|
|
75c1be69cf | ||
|
|
424ae28de9 | ||
|
|
d4a800edd5 | ||
|
|
dd9917f1a8 | ||
|
|
8df8c1012f | ||
|
|
bfa5c21d2d | ||
|
|
b1247a14ba | ||
|
|
f595057a0b | ||
|
|
089d442fb2 | ||
|
|
04a3765391 | ||
|
|
d24d8328f9 | ||
|
|
4f63996ae8 | ||
|
|
bdf2994e97 | ||
|
|
6d654acbad | ||
|
|
3e43298471 | ||
|
|
0ad2590974 | ||
|
|
9d11257b1a | ||
|
|
4ee1635baa | ||
|
|
75feb0da8b | ||
|
|
87376afd13 | ||
|
|
b76d9e8e9e | ||
|
|
e71383dcb9 | ||
|
|
e002a2e6e8 | ||
|
|
6127a01196 | ||
|
|
de27d6df36 | ||
|
|
3c535cdd2b | ||
|
|
0f050e5fe1 | ||
|
|
0f7c7f1da2 | ||
|
|
b56f61bf1b | ||
|
|
64f111923e | ||
|
|
122a89c02b | ||
|
|
c6cceba381 | ||
|
|
6c0cdb6ed1 | ||
|
|
84354951d3 | ||
|
|
55957a1960 | ||
|
|
df82a45d99 | ||
|
|
9424b88db2 | ||
|
|
609654eee7 | ||
|
|
b604c66140 | ||
|
|
ea4d13e96d | ||
|
|
87148c503f | ||
|
|
0cd36baf67 | ||
|
|
06980e7fa0 | ||
|
|
1ce4ee0cef | ||
|
|
f367925496 | ||
|
|
616b19c064 | ||
|
|
af27aaf9af | ||
|
|
35287f8241 | ||
|
|
07b220d91b | ||
|
|
41cd4952f1 | ||
|
|
f16f0c7831 | ||
|
|
aa07b3b87b | ||
|
|
2bef214cc0 | ||
|
|
cfb2d82352 | ||
|
|
684501fd35 | ||
|
|
0492c1724a | ||
|
|
6f436e57b5 | ||
|
|
a0d28f9851 | ||
|
|
cdd27a9fe5 | ||
|
|
5523040acd | ||
|
|
670446d42e | ||
|
|
5bed6777d5 | ||
|
|
a0482ebc7b | ||
|
|
2a89d6e47a | ||
|
|
24f932b2ce | ||
|
|
c03435061c | ||
|
|
8e948739f1 | ||
|
|
9b53cad752 | ||
|
|
802a18167c | ||
|
|
e9108ffe6c | ||
|
|
e806d9de38 | ||
|
|
daa8380df9 | ||
|
|
4785f23fc4 | ||
|
|
1d4cfb83e7 | ||
|
|
207fa059d2 | ||
|
|
cbcdad7814 | ||
|
|
701c13807a | ||
|
|
99f8dc7748 | ||
|
|
f1de8e6eb0 | ||
|
|
b2a10780af | ||
|
|
43ae79d848 | ||
|
|
e520b64c6d | ||
|
|
92c91bbdd8 | ||
|
|
adf494e1ac | ||
|
|
2158461121 | ||
|
|
0cd4b601c3 | ||
|
|
ee1cec47b3 | ||
|
|
efb0edfc4c | ||
|
|
20f59ddecb | ||
|
|
2f34e984b0 | ||
|
|
d5b52e86b6 | ||
|
|
cad2fe1f39 | ||
|
|
fcd2c15a37 | ||
|
|
ebda0fc538 | ||
|
|
ac135ab11d | ||
|
|
25faf9283d | ||
|
|
59faaa99f6 | ||
|
|
9762b39f29 | ||
|
|
ffdd115ded | ||
|
|
055df9854c | ||
|
|
12f883badf | ||
|
|
2abb92b0d4 | ||
|
|
01c3719c5d | ||
|
|
7b64953eed | ||
|
|
9bc7d788f0 | ||
|
|
b5419ef11a | ||
|
|
d5081cef90 | ||
|
|
488e619ec7 | ||
|
|
d2b42c8f68 | ||
|
|
2f44fe2e23 | ||
|
|
d8dc107bee | ||
|
|
3fa915e271 | ||
|
|
47c3afe561 | ||
|
|
84bfecdd37 | ||
|
|
3cf87b6846 | ||
|
|
4fe4c2054d | ||
|
|
38ada44a0e | ||
|
|
dbf81a145e | ||
|
|
39483f8ca8 | ||
|
|
c0eaea938e | ||
|
|
ef8b8a2891 | ||
|
|
2817f62c13 | ||
|
|
4a9049566a | ||
|
|
85f92f8321 | ||
|
|
714beb6e3b | ||
|
|
400b9fca32 | ||
|
|
4013298e22 | ||
|
|
312bfd9bd7 | ||
|
|
8db05838ca | ||
|
|
c69df13515 | ||
|
|
986eb8c1e0 | ||
|
|
197761ba4d | ||
|
|
f74ea64c7b | ||
|
|
3b7b9d25bc | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c |
3
.dockerignore-client
Normal file
3
.dockerignore-client
Normal file
@@ -0,0 +1,3 @@
|
||||
*
|
||||
!client/netbird-entrypoint.sh
|
||||
!netbird
|
||||
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,16 +37,21 @@ If yes, which one?
|
||||
|
||||
**Debug output**
|
||||
|
||||
To help us resolve the problem, please attach the following debug output
|
||||
To help us resolve the problem, please attach the following anonymized status output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
As well as the file created by
|
||||
Create and upload a debug bundle, and share the returned file key:
|
||||
|
||||
netbird debug for 1m -AS -U
|
||||
|
||||
*Uploaded files are automatically deleted after 30 days.*
|
||||
|
||||
|
||||
Alternatively, create the file only and attach it here manually:
|
||||
|
||||
netbird debug for 1m -AS
|
||||
|
||||
|
||||
We advise reviewing the anonymized output for any remaining personal information.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -13,3 +13,5 @@
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] Extended the README / documentation, if necessary
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
4
.github/workflows/git-town.yml
vendored
4
.github/workflows/git-town.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1
|
||||
- uses: git-town/action@v1.2.1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
skip-single-stacks: true
|
||||
|
||||
263
.github/workflows/golang-test-linux.yml
vendored
263
.github/workflows/golang-test-linux.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -24,8 +24,8 @@ jobs:
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
management:
|
||||
- 'management/**'
|
||||
management:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
@@ -146,13 +146,9 @@ jobs:
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_relay:
|
||||
name: "Relay / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -164,6 +160,79 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
id: go-env
|
||||
run: |
|
||||
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
${{ steps.go-env.outputs.cache_dir }}
|
||||
${{ steps.go-env.outputs.modcache_dir }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Run tests in container
|
||||
env:
|
||||
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||
CONTAINER: "true"
|
||||
run: |
|
||||
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||
|
||||
docker run --rm \
|
||||
--cap-add=NET_ADMIN \
|
||||
--privileged \
|
||||
-v $PWD:/app \
|
||||
-w /app \
|
||||
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
|
||||
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
|
||||
-e CGO_ENABLED=1 \
|
||||
-e CI=true \
|
||||
-e DOCKER_CI=true \
|
||||
-e GOARCH=${GOARCH_TARGET} \
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.23-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
name: "Relay / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- arch: "386"
|
||||
raceFlag: ""
|
||||
- arch: "amd64"
|
||||
raceFlag: ""
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -179,13 +248,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -195,9 +257,9 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
-timeout 10m ./relay/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
@@ -217,6 +279,10 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -232,13 +298,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -286,13 +345,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -314,6 +366,7 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/...
|
||||
@@ -353,13 +406,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -380,10 +426,11 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./...
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -396,6 +443,33 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
run: docker network create promnet
|
||||
|
||||
- name: Start Prometheus Pushgateway
|
||||
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||
|
||||
- name: Start Prometheus (for Pushgateway forwarding)
|
||||
run: |
|
||||
echo '
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_configs:
|
||||
- job_name: "pushgateway"
|
||||
static_configs:
|
||||
- targets: ["pushgateway:9091"]
|
||||
remote_write:
|
||||
- url: ${{ secrets.GRAFANA_URL }}
|
||||
basic_auth:
|
||||
username: ${{ secrets.GRAFANA_USER }}
|
||||
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||
' > prometheus.yml
|
||||
|
||||
docker run -d --name prometheus --network promnet \
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@@ -420,13 +494,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -447,11 +514,13 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_integration_test:
|
||||
@@ -489,13 +558,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -505,89 +567,8 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
||||
|
||||
- run: chmod +x *testing.bin
|
||||
|
||||
- name: Run Shared Sock tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,6 @@ jobs:
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
|
||||
23
.github/workflows/release.yml
vendored
23
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
SIGN_PIPE_VER: "v0.0.21"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -65,6 +65,13 @@ jobs:
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
@@ -224,3 +231,17 @@ jobs:
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||
|
||||
post_on_forum:
|
||||
runs-on: ubuntu-latest
|
||||
continue-on-error: true
|
||||
needs: [trigger_signer]
|
||||
steps:
|
||||
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
|
||||
@@ -134,6 +134,7 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -172,13 +173,15 @@ jobs:
|
||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
# check relay values
|
||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||
grep '33445:33445' docker-compose.yml
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -30,3 +30,4 @@ infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
vendor/
|
||||
/netbird
|
||||
|
||||
212
.goreleaser.yaml
212
.goreleaser.yaml
@@ -96,6 +96,20 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-upload
|
||||
dir: upload-server
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-upload
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -135,97 +149,119 @@ nfpms:
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
extra_files:
|
||||
- client/netbird-entrypoint.sh
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: amd64
|
||||
@@ -237,10 +273,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm64
|
||||
@@ -252,10 +289,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm
|
||||
@@ -268,10 +306,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: amd64
|
||||
@@ -283,10 +322,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm64
|
||||
@@ -298,10 +338,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm
|
||||
@@ -314,10 +355,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
@@ -329,10 +371,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
@@ -344,10 +387,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
@@ -360,10 +404,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
@@ -375,10 +420,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
@@ -390,11 +436,12 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
@@ -407,7 +454,56 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
@@ -475,7 +571,95 @@ docker_manifests:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- name_template: netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/upload:latest
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:debug-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
|
||||
18
README.md
18
README.md
@@ -12,8 +12,11 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
@@ -29,13 +32,13 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||
New: NetBird Kubernetes Operator
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -47,10 +50,9 @@
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open-Source Network Security in a Single Platform
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
|
||||

|
||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
@@ -61,7 +63,7 @@
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
|
||||
@@ -1,5 +1,27 @@
|
||||
FROM alpine:3.21.3
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
# build & run locally with:
|
||||
# cd "$(git rev-parse --show-toplevel)"
|
||||
# CGO_ENABLED=0 go build -o netbird ./client
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.0
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
ca-certificates \
|
||||
ip6tables \
|
||||
iproute2 \
|
||||
iptables
|
||||
|
||||
ENV \
|
||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
|
||||
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
|
||||
|
||||
@@ -1,17 +1,33 @@
|
||||
FROM alpine:3.21.0
|
||||
# build & run locally with:
|
||||
# cd "$(git rev-parse --show-toplevel)"
|
||||
# CGO_ENABLED=0 go build -o netbird ./client
|
||||
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
FROM alpine:3.22.0
|
||||
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
ca-certificates \
|
||||
&& adduser -D -h /var/lib/netbird netbird
|
||||
|
||||
WORKDIR /var/lib/netbird
|
||||
USER netbird:netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV NB_USE_NETSTACK_MODE=true
|
||||
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||
ENV NB_CONFIG=config.json
|
||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||
ENV NB_DISABLE_DNS=true
|
||||
ENV \
|
||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||
NB_USE_NETSTACK_MODE="true" \
|
||||
NB_ENABLE_NETSTACK_LOCAL_FORWARDING="true" \
|
||||
NB_CONFIG="/var/lib/netbird/config.json" \
|
||||
NB_STATE_DIR="/var/lib/netbird" \
|
||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||
NB_DISABLE_DNS="true" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
|
||||
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -59,10 +60,14 @@ type Client struct {
|
||||
deviceName string
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
execWorkaround(androidSDKVersion)
|
||||
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
@@ -78,7 +83,7 @@ func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapt
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -106,14 +111,14 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -132,8 +137,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -174,6 +179,55 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
func (c *Client) Networks() *NetworkArray {
|
||||
if c.connectClient == nil {
|
||||
log.Error("not connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
log.Error("could not get engine")
|
||||
return nil
|
||||
}
|
||||
|
||||
routeManager := engine.GetRouteManager()
|
||||
if routeManager == nil {
|
||||
log.Error("could not get route manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
netStr := r.Network.String()
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
return networkArray
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
|
||||
26
client/android/exec.go
Normal file
26
client/android/exec.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
|
||||
// In Android version 11 and earlier, pidfd-related system calls
|
||||
// are not allowed by the seccomp policy, which causes crashes due
|
||||
// to SIGSYS signals.
|
||||
|
||||
//go:linkname checkPidfdOnce os.checkPidfdOnce
|
||||
var checkPidfdOnce func() error
|
||||
|
||||
func execWorkaround(androidSDKVersion int) {
|
||||
if androidSDKVersion > 30 { // above Android 11
|
||||
return
|
||||
}
|
||||
|
||||
checkPidfdOnce = func() error {
|
||||
return fmt.Errorf("unsupported Android version")
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
@@ -37,17 +38,17 @@ type URLOpener interface {
|
||||
// Auth can register or login new client
|
||||
type Auth struct {
|
||||
ctx context.Context
|
||||
config *internal.Config
|
||||
config *profilemanager.Config
|
||||
cfgPath string
|
||||
}
|
||||
|
||||
// NewAuth instantiate Auth struct and validate the management URL
|
||||
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
inputCfg := internal.ConfigInput{
|
||||
inputCfg := profilemanager.ConfigInput{
|
||||
ManagementURL: mgmURL,
|
||||
}
|
||||
|
||||
cfg, err := internal.CreateInMemoryConfig(inputCfg)
|
||||
cfg, err := profilemanager.CreateInMemoryConfig(inputCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -60,7 +61,7 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
|
||||
}
|
||||
|
||||
// NewAuthWithConfig instantiate Auth based on existing config
|
||||
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
|
||||
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
|
||||
return &Auth{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
@@ -110,7 +111,7 @@ func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
return true, err
|
||||
}
|
||||
|
||||
@@ -142,7 +143,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
return internal.WriteOutConfig(a.cfgPath, a.config)
|
||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||
}
|
||||
|
||||
// Login try register the client on the server
|
||||
|
||||
27
client/android/networks.go
Normal file
27
client/android/networks.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
}
|
||||
|
||||
type NetworkArray struct {
|
||||
items []Network
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Add(s Network) *NetworkArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Get(i int) *Network {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
@@ -7,30 +7,23 @@ type PeerInfo struct {
|
||||
ConnStatus string // Todo replace to enum
|
||||
}
|
||||
|
||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||
type PeerInfoCollection interface {
|
||||
Add(s string) PeerInfoCollection
|
||||
Get(i int) string
|
||||
Size() int
|
||||
}
|
||||
|
||||
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||
// PeerInfoArray is a wrapper of []PeerInfo
|
||||
type PeerInfoArray struct {
|
||||
items []PeerInfo
|
||||
}
|
||||
|
||||
// Add new PeerInfo to the collection
|
||||
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||
func (array *PeerInfoArray) Get(i int) *PeerInfo {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array PeerInfoArray) Size() int {
|
||||
func (array *PeerInfoArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
|
||||
@@ -1,78 +1,226 @@
|
||||
package android
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
// Preferences export a subset of the internal config for gomobile
|
||||
// Preferences exports a subset of the internal config for gomobile
|
||||
type Preferences struct {
|
||||
configInput internal.ConfigInput
|
||||
configInput profilemanager.ConfigInput
|
||||
}
|
||||
|
||||
// NewPreferences create new Preferences instance
|
||||
// NewPreferences creates a new Preferences instance
|
||||
func NewPreferences(configPath string) *Preferences {
|
||||
ci := internal.ConfigInput{
|
||||
ci := profilemanager.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
}
|
||||
return &Preferences{ci}
|
||||
}
|
||||
|
||||
// GetManagementURL read url from config file
|
||||
// GetManagementURL reads URL from config file
|
||||
func (p *Preferences) GetManagementURL() (string, error) {
|
||||
if p.configInput.ManagementURL != "" {
|
||||
return p.configInput.ManagementURL, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cfg.ManagementURL.String(), err
|
||||
}
|
||||
|
||||
// SetManagementURL store the given url and wait for commit
|
||||
// SetManagementURL stores the given URL and waits for commit
|
||||
func (p *Preferences) SetManagementURL(url string) {
|
||||
p.configInput.ManagementURL = url
|
||||
}
|
||||
|
||||
// GetAdminURL read url from config file
|
||||
// GetAdminURL reads URL from config file
|
||||
func (p *Preferences) GetAdminURL() (string, error) {
|
||||
if p.configInput.AdminURL != "" {
|
||||
return p.configInput.AdminURL, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cfg.AdminURL.String(), err
|
||||
}
|
||||
|
||||
// SetAdminURL store the given url and wait for commit
|
||||
// SetAdminURL stores the given URL and waits for commit
|
||||
func (p *Preferences) SetAdminURL(url string) {
|
||||
p.configInput.AdminURL = url
|
||||
}
|
||||
|
||||
// GetPreSharedKey read preshared key from config file
|
||||
// GetPreSharedKey reads pre-shared key from config file
|
||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
if p.configInput.PreSharedKey != nil {
|
||||
return *p.configInput.PreSharedKey, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cfg.PreSharedKey, err
|
||||
}
|
||||
|
||||
// SetPreSharedKey store the given key and wait for commit
|
||||
// SetPreSharedKey stores the given key and waits for commit
|
||||
func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// Commit write out the changes into config file
|
||||
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||
p.configInput.RosenpassEnabled = &enabled
|
||||
}
|
||||
|
||||
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
if p.configInput.RosenpassEnabled != nil {
|
||||
return *p.configInput.RosenpassEnabled, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.RosenpassEnabled, err
|
||||
}
|
||||
|
||||
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||
p.configInput.RosenpassPermissive = &permissive
|
||||
}
|
||||
|
||||
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
if p.configInput.RosenpassPermissive != nil {
|
||||
return *p.configInput.RosenpassPermissive, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.RosenpassPermissive, err
|
||||
}
|
||||
|
||||
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||
if p.configInput.DisableClientRoutes != nil {
|
||||
return *p.configInput.DisableClientRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableClientRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableClientRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||
p.configInput.DisableClientRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||
if p.configInput.DisableServerRoutes != nil {
|
||||
return *p.configInput.DisableServerRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableServerRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableServerRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||
p.configInput.DisableServerRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableDNS reads disable DNS setting from config file
|
||||
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||
if p.configInput.DisableDNS != nil {
|
||||
return *p.configInput.DisableDNS, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableDNS, err
|
||||
}
|
||||
|
||||
// SetDisableDNS stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||
p.configInput.DisableDNS = &disable
|
||||
}
|
||||
|
||||
// GetDisableFirewall reads disable firewall setting from config file
|
||||
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||
if p.configInput.DisableFirewall != nil {
|
||||
return *p.configInput.DisableFirewall, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableFirewall, err
|
||||
}
|
||||
|
||||
// SetDisableFirewall stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||
p.configInput.DisableFirewall = &disable
|
||||
}
|
||||
|
||||
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||
if p.configInput.ServerSSHAllowed != nil {
|
||||
return *p.configInput.ServerSSHAllowed, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.ServerSSHAllowed == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.ServerSSHAllowed, err
|
||||
}
|
||||
|
||||
// SetServerSSHAllowed stores the given value and waits for commit
|
||||
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
return *p.configInput.BlockInbound, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.BlockInbound, err
|
||||
}
|
||||
|
||||
// SetBlockInbound stores the given value and waits for commit
|
||||
func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
func TestPreferences_DefaultValues(t *testing.T) {
|
||||
@@ -15,7 +15,7 @@ func TestPreferences_DefaultValues(t *testing.T) {
|
||||
t.Fatalf("failed to read default value: %s", err)
|
||||
}
|
||||
|
||||
if defaultVar != internal.DefaultAdminURL {
|
||||
if defaultVar != profilemanager.DefaultAdminURL {
|
||||
t.Errorf("invalid default admin url: %s", defaultVar)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func TestPreferences_DefaultValues(t *testing.T) {
|
||||
t.Fatalf("failed to read default management URL: %s", err)
|
||||
}
|
||||
|
||||
if defaultVar != internal.DefaultManagementURL {
|
||||
if defaultVar != profilemanager.DefaultManagementURL {
|
||||
t.Errorf("invalid default management url: %s", defaultVar)
|
||||
}
|
||||
|
||||
|
||||
@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
return a.ipAnonymizer[ip]
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||
// Convert IP to netip.Addr
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return addr
|
||||
}
|
||||
|
||||
anonIP := a.AnonymizeIP(ip)
|
||||
|
||||
return net.UDPAddr{
|
||||
IP: anonIP.AsSlice(),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
}
|
||||
|
||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||
|
||||
@@ -11,13 +11,25 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
|
||||
var (
|
||||
logFileCount uint32
|
||||
systemInfoFlag bool
|
||||
uploadBundleFlag bool
|
||||
uploadBundleURLFlag string
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
@@ -84,16 +96,28 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
if uploadBundleFlag {
|
||||
request.UploadURL = uploadBundleURLFlag
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if uploadBundleFlag {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -208,23 +232,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
if uploadBundleFlag {
|
||||
request.UploadURL = uploadBundleURLFlag
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Disable network map persistence after creating the debug bundle
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
@@ -239,7 +260,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if uploadBundleFlag {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -279,7 +308,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
@@ -326,3 +355,46 @@ func formatDuration(d time.Duration) string {
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
|
||||
func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||
var networkMap *mgmProto.NetworkMap
|
||||
var err error
|
||||
|
||||
if connectClient != nil {
|
||||
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get latest network map: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
debug.GeneratorDependencies{
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
NetworkMap: networkMap,
|
||||
LogFile: logFilePath,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
||||
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
|
||||
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
|
||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
}
|
||||
|
||||
40
client/cmd/debug_unix.go
Normal file
40
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,40 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
usr1Ch := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-usr1Ch:
|
||||
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
127
client/cmd/debug_windows.go
Normal file
127
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||
|
||||
waitTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||
// Example usage with PowerShell:
|
||||
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||
// $evt.Set()
|
||||
// $evt.Close()
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
env := os.Getenv(envListenEvent)
|
||||
if env == "" {
|
||||
return
|
||||
}
|
||||
|
||||
listenEvent, err := strconv.ParseBool(env)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||
return
|
||||
}
|
||||
if !listenEvent {
|
||||
return
|
||||
}
|
||||
|
||||
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: restrict access by ACL
|
||||
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||
if err != nil {
|
||||
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||
} else {
|
||||
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if eventHandle == windows.InvalidHandle {
|
||||
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||
|
||||
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||
}
|
||||
|
||||
func waitForEvent(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
eventHandle windows.Handle,
|
||||
) {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||
|
||||
switch status {
|
||||
case windows.WAIT_OBJECT_0:
|
||||
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||
|
||||
// reset the event so it can be triggered again later (manual reset == 1)
|
||||
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
case uint32(windows.WAIT_TIMEOUT):
|
||||
|
||||
default:
|
||||
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,7 +20,7 @@ var downCmd = &cobra.Command{
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
log.Errorf("failed initializing log %v", err)
|
||||
return err
|
||||
|
||||
@@ -4,9 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -14,6 +17,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -21,19 +25,16 @@ import (
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "login to the Netbird Management Service (first run)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
if err := setEnvAndFlags(cmd); err != nil {
|
||||
return fmt.Errorf("set env and flags: %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
@@ -42,6 +43,17 @@ var loginCmd = &cobra.Command{
|
||||
// nolint
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
username, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
|
||||
activeProf, err := getActiveProfile(cmd.Context(), pm, profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
@@ -49,94 +61,15 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
}
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
if util.FindFirstLogPath(logFiles) == "" {
|
||||
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
cmd.Println("Logging successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
var dnsLabelsReq []string
|
||||
if dnsLabelsValidated != nil {
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
if err := doDaemonLogin(ctx, cmd, providedSetupKey, activeProf, username.Username, pm); err != nil {
|
||||
return fmt.Errorf("daemon login failed: %v", err)
|
||||
}
|
||||
|
||||
cmd.Println("Logging successfully")
|
||||
@@ -145,7 +78,196 @@ var loginCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
|
||||
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
var dnsLabelsReq []string
|
||||
if dnsLabelsValidated != nil {
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil {
|
||||
return fmt.Errorf("sso login failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
|
||||
// switch profile if provided
|
||||
|
||||
if profileName != "" {
|
||||
if err := switchProfileOnDaemon(ctx, pm, profileName, username); err != nil {
|
||||
return nil, fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
|
||||
if activeProf == nil {
|
||||
return nil, fmt.Errorf("active profile not found, please run 'netbird profile create' first")
|
||||
}
|
||||
return activeProf, nil
|
||||
}
|
||||
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
|
||||
err := switchProfile(context.Background(), profileName, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile on daemon: %v", err)
|
||||
}
|
||||
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to service CLI interface %v", err)
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("call service down method: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update host's static platform and system information
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
configFilePath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile file path: %v", err)
|
||||
|
||||
}
|
||||
|
||||
config, err := profilemanager.ReadConfig(configFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
cmd.Println("Logging successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
|
||||
if resp.Email != "" {
|
||||
err = pm.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
Email: resp.Email,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("failed to set active profile email: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -191,8 +313,8 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -240,7 +362,23 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
}
|
||||
}
|
||||
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
func setEnvAndFlags(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -14,40 +14,41 @@ func TestLogin(t *testing.T) {
|
||||
mgmAddr := startTestingServices(t)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
confPath := tempDir + "/config.json"
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get current user: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
})
|
||||
|
||||
mgmtURL := fmt.Sprintf("http://%s", mgmAddr)
|
||||
rootCmd.SetArgs([]string{
|
||||
"login",
|
||||
"--config",
|
||||
confPath,
|
||||
"--log-file",
|
||||
"console",
|
||||
util.LogConsole,
|
||||
"--setup-key",
|
||||
strings.ToUpper("a2c8e62b-38f5-4553-b31e-dd66c696cebb"),
|
||||
"--management-url",
|
||||
mgmtURL,
|
||||
})
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// validate generated config
|
||||
actualConf := &internal.Config{}
|
||||
_, err = util.ReadJson(confPath, actualConf)
|
||||
if err != nil {
|
||||
t.Errorf("expected proper config file written, got broken %v", err)
|
||||
}
|
||||
|
||||
if actualConf.ManagementURL.String() != mgmtURL {
|
||||
t.Errorf("expected management URL %s got %s", mgmtURL, actualConf.ManagementURL.String())
|
||||
}
|
||||
|
||||
if actualConf.WgIface != iface.WgInterfaceDefault {
|
||||
t.Errorf("expected WgIfaceName %s got %s", iface.WgInterfaceDefault, actualConf.WgIface)
|
||||
}
|
||||
|
||||
if len(actualConf.PrivateKey) == 0 {
|
||||
t.Errorf("expected non empty Private key, got empty")
|
||||
}
|
||||
// TODO(hakan): fix this test
|
||||
_ = rootCmd.Execute()
|
||||
}
|
||||
|
||||
236
client/cmd/profile.go
Normal file
236
client/cmd/profile.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"os/user"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var profileCmd = &cobra.Command{
|
||||
Use: "profile",
|
||||
Short: "manage Netbird profiles",
|
||||
Long: `Manage Netbird profiles, allowing you to list, switch, and remove profiles.`,
|
||||
}
|
||||
|
||||
var profileListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "list all profiles",
|
||||
Long: `List all available profiles in the Netbird client.`,
|
||||
RunE: listProfilesFunc,
|
||||
}
|
||||
|
||||
var profileAddCmd = &cobra.Command{
|
||||
Use: "add <profile_name>",
|
||||
Short: "add a new profile",
|
||||
Long: `Add a new profile to the Netbird client. The profile name must be unique.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: addProfileFunc,
|
||||
}
|
||||
|
||||
var profileRemoveCmd = &cobra.Command{
|
||||
Use: "remove <profile_name>",
|
||||
Short: "remove a profile",
|
||||
Long: `Remove a profile from the Netbird client. The profile must not be active.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
}
|
||||
|
||||
var profileSelectCmd = &cobra.Command{
|
||||
Use: "select <profile_name>",
|
||||
Short: "select a profile",
|
||||
Long: `Select a profile to be the active profile in the Netbird client. The profile must exist.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: selectProfileFunc,
|
||||
}
|
||||
|
||||
func setupCmd(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// list profiles, add a tick if the profile is active
|
||||
cmd.Println("Found", len(profiles.Profiles), "profiles:")
|
||||
for _, profile := range profiles.Profiles {
|
||||
// use a cross to indicate the passive profiles
|
||||
activeMarker := "✗"
|
||||
if profile.IsActive {
|
||||
activeMarker = "✓"
|
||||
}
|
||||
cmd.Println(activeMarker, profile.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Profile added successfully:", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Profile removed successfully:", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
profileManager := profilemanager.NewProfileManager()
|
||||
profileName := args[0]
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
defer cancel()
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list profiles: %w", err)
|
||||
}
|
||||
|
||||
var profileExists bool
|
||||
|
||||
for _, profile := range profiles.Profiles {
|
||||
if profile.Name == profileName {
|
||||
profileExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !profileExists {
|
||||
return fmt.Errorf("profile %s does not exist", profileName)
|
||||
}
|
||||
|
||||
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := daemonClient.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get service status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
if _, err := daemonClient.Down(ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("call service down method: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Profile switched successfully to:", profileName)
|
||||
return nil
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"os/signal"
|
||||
"path"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -21,28 +22,26 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
)
|
||||
|
||||
var (
|
||||
configPath string
|
||||
defaultConfigPathDir string
|
||||
defaultConfigPath string
|
||||
oldDefaultConfigPathDir string
|
||||
@@ -52,7 +51,7 @@ var (
|
||||
defaultLogFile string
|
||||
oldDefaultLogFileDir string
|
||||
oldDefaultLogFile string
|
||||
logFile string
|
||||
logFiles []string
|
||||
daemonAddr string
|
||||
managementURL string
|
||||
adminURL string
|
||||
@@ -68,13 +67,12 @@ var (
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
serviceName string
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
lazyConnEnabled bool
|
||||
profilesDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -118,26 +116,19 @@ func init() {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
}
|
||||
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", profilemanager.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
|
||||
rootCmd.PersistentFlags().StringSliceVar(&logFiles, "log-file", []string{defaultLogFile}, "sets Netbird log paths written to simultaneously. If `console` is specified the log will be output to stdout. If `syslog` is specified the log will be sent to syslog daemon. You can pass the flag multiple times or separate entries by `,` character")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
|
||||
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "(DEPRECATED) Netbird config file location")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
rootCmd.AddCommand(statusCmd)
|
||||
@@ -147,9 +138,7 @@ func init() {
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
@@ -162,6 +151,12 @@ func init() {
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
profileCmd.AddCommand(profileAddCmd)
|
||||
profileCmd.AddCommand(profileRemoveCmd)
|
||||
profileCmd.AddCommand(profileSelectCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
@@ -179,8 +174,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -188,14 +183,13 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) {
|
||||
termCh := make(chan os.Signal, 1)
|
||||
signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
done := ctx.Done()
|
||||
defer cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
case <-termCh:
|
||||
}
|
||||
|
||||
log.Info("shutdown signal received")
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -279,7 +273,7 @@ func getSetupKeyFromFile(setupKeyPath string) (string, error) {
|
||||
|
||||
func handleRebrand(cmd *cobra.Command) error {
|
||||
var err error
|
||||
if logFile == defaultLogFile {
|
||||
if slices.Contains(logFiles, defaultLogFile) {
|
||||
if migrateToNetbird(oldDefaultLogFile, defaultLogFile) {
|
||||
cmd.Printf("will copy Log dir %s and its content to %s\n", oldDefaultLogFileDir, defaultLogFileDir)
|
||||
err = cpDir(oldDefaultLogFileDir, defaultLogFileDir)
|
||||
@@ -288,15 +282,14 @@ func handleRebrand(cmd *cobra.Command) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
if configPath == defaultConfigPath {
|
||||
if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) {
|
||||
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir)
|
||||
err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if migrateToNetbird(oldDefaultConfigPath, defaultConfigPath) {
|
||||
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultConfigPathDir, defaultConfigPathDir)
|
||||
err = cpDir(oldDefaultConfigPathDir, defaultConfigPathDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@@ -13,6 +17,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
var serviceCmd = &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "manages Netbird service",
|
||||
}
|
||||
|
||||
var (
|
||||
serviceName string
|
||||
serviceEnvVars []string
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -21,30 +35,81 @@ type program struct {
|
||||
serverInstanceMu sync.Mutex
|
||||
}
|
||||
|
||||
func init() {
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||
`E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return &program{ctx: ctx, cancel: cancel}
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
return &service.Config{
|
||||
func newSVCConfig() (*service.Config, error) {
|
||||
config := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Description: "Netbird mesh network client",
|
||||
Option: make(service.KeyValue),
|
||||
EnvVars: make(map[string]string),
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
extraEnvs, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse service environment variables: %w", err)
|
||||
}
|
||||
config.EnvVars = extraEnvs
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||
s, err := service.New(prg, conf)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return service.New(prg, conf)
|
||||
}
|
||||
|
||||
var serviceCmd = &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "manages Netbird service",
|
||||
func parseServiceEnvVars(envVars []string) (map[string]string, error) {
|
||||
envMap := make(map[string]string)
|
||||
|
||||
for _, env := range envVars {
|
||||
if env == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("empty environment variable key in: %s", env)
|
||||
}
|
||||
|
||||
envMap[key] = value
|
||||
}
|
||||
|
||||
return envMap, nil
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
@@ -16,12 +18,17 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func (p *program) Start(svc service.Service) error {
|
||||
// Start should not block. Do the actual work async.
|
||||
log.Info("starting Netbird service") //nolint
|
||||
|
||||
// Collect static system and platform information
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||
p.serv = grpc.NewServer()
|
||||
|
||||
@@ -42,20 +49,19 @@ func (p *program) Start(svc service.Service) error {
|
||||
|
||||
listen, err := net.Listen(split[0], split[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen daemon interface: %w", err)
|
||||
return fmt.Errorf("listen daemon interface: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer listen.Close()
|
||||
|
||||
if split[0] == "unix" {
|
||||
err = os.Chmod(split[1], 0666)
|
||||
if err != nil {
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
serverInstance := server.New(p.ctx, configPath, logFile)
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), profilesDisabled)
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
@@ -95,36 +101,49 @@ func (p *program) Stop(srv service.Service) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Common setup for service control commands
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
if err := handleRebrand(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
||||
return nil, fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create service config: %w", err)
|
||||
}
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var runCmd = &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "runs Netbird as service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
|
||||
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.Run()
|
||||
},
|
||||
}
|
||||
|
||||
@@ -132,31 +151,14 @@ var startCmd = &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "starts Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
}
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
|
||||
if err := s.Start(); err != nil {
|
||||
return fmt.Errorf("start service: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been started")
|
||||
return nil
|
||||
@@ -167,29 +169,14 @@ var stopCmd = &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "stops Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
return fmt.Errorf("stop service: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been stopped")
|
||||
return nil
|
||||
@@ -200,31 +187,48 @@ var restartCmd = &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "restarts Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Restart()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
if err := s.Restart(); err != nil {
|
||||
return fmt.Errorf("restart service: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been restarted")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var svcStatusCmd = &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "shows Netbird service status",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get service status: %w", err)
|
||||
}
|
||||
|
||||
var statusText string
|
||||
switch status {
|
||||
case service.StatusRunning:
|
||||
statusText = "Running"
|
||||
case service.StatusStopped:
|
||||
statusText = "Stopped"
|
||||
case service.StatusUnknown:
|
||||
statusText = "Unknown"
|
||||
default:
|
||||
statusText = fmt.Sprintf("Unknown (%d)", status)
|
||||
}
|
||||
|
||||
cmd.Printf("Netbird service status: %s\n", statusText)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,87 +1,121 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var ErrGetServiceStatus = fmt.Errorf("failed to get service status")
|
||||
|
||||
// Common service command setup
|
||||
func setupServiceCommand(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
return handleRebrand(cmd)
|
||||
}
|
||||
|
||||
// Build service arguments for install/reconfigure
|
||||
func buildServiceArguments() []string {
|
||||
args := []string{
|
||||
"service",
|
||||
"run",
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
args = append(args, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
for _, logFile := range logFiles {
|
||||
args = append(args, "--log-file", logFile)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// Configure platform-specific service settings
|
||||
func configurePlatformSpecificSettings(svcConfig *service.Config) error {
|
||||
if runtime.GOOS == "linux" {
|
||||
// Respected only by systemd systems
|
||||
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
|
||||
|
||||
if logFile := util.FindFirstLogPath(logFiles); logFile != "" {
|
||||
setStdLogPath := true
|
||||
dir := filepath.Dir(logFile)
|
||||
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
if err = os.MkdirAll(dir, 0750); err != nil {
|
||||
setStdLogPath = false
|
||||
}
|
||||
}
|
||||
|
||||
if setStdLogPath {
|
||||
svcConfig.Option["LogOutput"] = true
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
svcConfig.Option["OnFailure"] = "restart"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create fully configured service config for install/reconfigure
|
||||
func createServiceConfigForInstall() (*service.Config, error) {
|
||||
svcConfig, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create service config: %w", err)
|
||||
}
|
||||
|
||||
svcConfig.Arguments = buildServiceArguments()
|
||||
if err = configurePlatformSpecificSettings(svcConfig); err != nil {
|
||||
return nil, fmt.Errorf("configure platform-specific settings: %w", err)
|
||||
}
|
||||
|
||||
return svcConfig, nil
|
||||
}
|
||||
|
||||
var installCmd = &cobra.Command{
|
||||
Use: "install",
|
||||
Short: "installs Netbird service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
if err := setupServiceCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svcConfig := newSVCConfig()
|
||||
|
||||
svcConfig.Arguments = []string{
|
||||
"service",
|
||||
"run",
|
||||
"--config",
|
||||
configPath,
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
if logFile != "console" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
// Respected only by systemd systems
|
||||
svcConfig.Dependencies = []string{"After=network.target syslog.target"}
|
||||
|
||||
if logFile != "console" {
|
||||
setStdLogPath := true
|
||||
dir := filepath.Dir(logFile)
|
||||
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0750)
|
||||
if err != nil {
|
||||
setStdLogPath = false
|
||||
}
|
||||
}
|
||||
|
||||
if setStdLogPath {
|
||||
svcConfig.Option["LogOutput"] = true
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
svcConfig.Option["OnFailure"] = "restart"
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.Install()
|
||||
if err != nil {
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
if err := s.Install(); err != nil {
|
||||
return fmt.Errorf("install service: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Netbird service has been installed")
|
||||
@@ -93,27 +127,109 @@ var uninstallCmd = &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "uninstalls Netbird service from system",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
if err := setupServiceCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create service config: %w", err)
|
||||
}
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.Uninstall(); err != nil {
|
||||
return fmt.Errorf("uninstall service: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Netbird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var reconfigureCmd = &cobra.Command{
|
||||
Use: "reconfigure",
|
||||
Short: "reconfigures Netbird service with new settings",
|
||||
Long: `Reconfigures the Netbird service with new settings without manual uninstall/install.
|
||||
This command will temporarily stop the service, update its configuration, and restart it if it was running.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := setupServiceCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wasRunning, err := isServiceRunning()
|
||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||
return fmt.Errorf("check service status: %w", err)
|
||||
}
|
||||
|
||||
svcConfig, err := createServiceConfigForInstall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
|
||||
err = s.Uninstall()
|
||||
if err != nil {
|
||||
return err
|
||||
if wasRunning {
|
||||
cmd.Println("Stopping Netbird service...")
|
||||
if err := s.Stop(); err != nil {
|
||||
cmd.Printf("Warning: failed to stop service: %v\n", err)
|
||||
}
|
||||
}
|
||||
cmd.Println("Netbird service has been uninstalled")
|
||||
|
||||
cmd.Println("Removing existing service configuration...")
|
||||
if err := s.Uninstall(); err != nil {
|
||||
return fmt.Errorf("uninstall existing service: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Installing service with new configuration...")
|
||||
if err := s.Install(); err != nil {
|
||||
return fmt.Errorf("install service with new config: %w", err)
|
||||
}
|
||||
|
||||
if wasRunning {
|
||||
cmd.Println("Starting Netbird service...")
|
||||
if err := s.Start(); err != nil {
|
||||
return fmt.Errorf("start service after reconfigure: %w", err)
|
||||
}
|
||||
cmd.Println("Netbird service has been reconfigured and started")
|
||||
} else {
|
||||
cmd.Println("Netbird service has been reconfigured")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func isServiceRunning() (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %w", ErrGetServiceStatus, err)
|
||||
}
|
||||
|
||||
return status == service.StatusRunning, nil
|
||||
}
|
||||
|
||||
263
client/cmd/service_test.go
Normal file
263
client/cmd/service_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
statusPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// waitForServiceStatus waits for service to reach expected status with timeout
|
||||
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer timeoutCancel()
|
||||
|
||||
ticker := time.NewTicker(statusPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
|
||||
case <-ticker.C:
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
// Continue polling on transient errors
|
||||
continue
|
||||
}
|
||||
if status == expectedStatus {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceLifecycle tests the complete service lifecycle
|
||||
func TestServiceLifecycle(t *testing.T) {
|
||||
// TODO: Add support for Windows and macOS
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
if os.Getenv("CONTAINER") == "true" {
|
||||
t.Skip("Skipping service lifecycle test in container environment")
|
||||
}
|
||||
|
||||
originalServiceName := serviceName
|
||||
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
installCmd.SetContext(ctx)
|
||||
err := installCmd.RunE(installCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := s.Status()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, service.StatusUnknown, status)
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
startCmd.SetContext(ctx)
|
||||
err := startCmd.RunE(startCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Restart", func(t *testing.T) {
|
||||
restartCmd.SetContext(ctx)
|
||||
err := restartCmd.RunE(restartCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Reconfigure", func(t *testing.T) {
|
||||
originalLogLevel := logLevel
|
||||
logLevel = "debug"
|
||||
defer func() {
|
||||
logLevel = originalLogLevel
|
||||
}()
|
||||
|
||||
reconfigureCmd.SetContext(ctx)
|
||||
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, running)
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
stopCmd.SetContext(ctx)
|
||||
err := stopCmd.RunE(stopCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, stopped)
|
||||
})
|
||||
|
||||
t.Run("Uninstall", func(t *testing.T) {
|
||||
uninstallCmd.SetContext(ctx)
|
||||
err := uninstallCmd.RunE(uninstallCmd, []string{})
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.Status()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestServiceEnvVars tests environment variable parsing
|
||||
func TestServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVars []string
|
||||
expected map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid single env var",
|
||||
envVars: []string{"LOG_LEVEL=debug"},
|
||||
expected: map[string]string{
|
||||
"LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid multiple env vars",
|
||||
envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"},
|
||||
expected: map[string]string{
|
||||
"LOG_LEVEL": "debug",
|
||||
"CUSTOM_VAR": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Env var with spaces",
|
||||
envVars: []string{" KEY = value "},
|
||||
expected: map[string]string{
|
||||
"KEY": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid format - no equals",
|
||||
envVars: []string{"INVALID"},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid format - empty key",
|
||||
envVars: []string{"=value"},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Empty value is valid",
|
||||
envVars: []string{"KEY="},
|
||||
expected: map[string]string{
|
||||
"KEY": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
envVars: []string{},
|
||||
expected: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "Empty string in slice",
|
||||
envVars: []string{"", "KEY=value", ""},
|
||||
expected: map[string]string{"KEY": "value"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseServiceEnvVars(tt.envVars)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceConfigWithEnvVars tests service config creation with env vars
|
||||
func TestServiceConfigWithEnvVars(t *testing.T) {
|
||||
originalServiceName := serviceName
|
||||
originalServiceEnvVars := serviceEnvVars
|
||||
defer func() {
|
||||
serviceName = originalServiceName
|
||||
serviceEnvVars = originalServiceEnvVars
|
||||
}()
|
||||
|
||||
serviceName = "test-service"
|
||||
serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"}
|
||||
|
||||
cfg, err := newSVCConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test-service", cfg.Name)
|
||||
assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"])
|
||||
assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"])
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"])
|
||||
}
|
||||
}
|
||||
@@ -12,14 +12,15 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
user = "root"
|
||||
host string
|
||||
port int
|
||||
userName = "root"
|
||||
host string
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
@@ -31,7 +32,7 @@ var sshCmd = &cobra.Command{
|
||||
|
||||
split := strings.Split(args[0], "@")
|
||||
if len(split) == 2 {
|
||||
user = split[0]
|
||||
userName = split[0]
|
||||
host = split[1]
|
||||
} else {
|
||||
host = args[0]
|
||||
@@ -46,7 +47,7 @@ var sshCmd = &cobra.Command{
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
@@ -58,11 +59,19 @@ var sshCmd = &cobra.Command{
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
config, err := internal.UpdateConfig(internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
profPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile path: %v", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.ReadConfig(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read profile config: %v", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
@@ -89,7 +98,7 @@ var sshCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -26,6 +27,7 @@ var (
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
connectionTypeFilter string
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -44,7 +46,8 @@ func init() {
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -57,7 +60,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = util.InitLog(logLevel, "console")
|
||||
err = util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
@@ -69,7 +72,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
status := resp.GetStatus()
|
||||
|
||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired) {
|
||||
cmd.Printf("Daemon status: %s\n\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
@@ -86,7 +92,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
@@ -117,7 +129,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -127,12 +139,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
|
||||
func parseFilters() error {
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
case "", "idle", "connecting", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
@@ -153,6 +165,15 @@ func parseFilters() error {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
|
||||
switch strings.ToLower(connectionTypeFilter) {
|
||||
case "", "p2p", "relayed":
|
||||
if strings.ToLower(connectionTypeFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ const (
|
||||
disableServerRoutesFlag = "disable-server-routes"
|
||||
disableDNSFlag = "disable-dns"
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
blockInboundFlag = "block-inbound"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -13,6 +15,8 @@ var (
|
||||
disableServerRoutes bool
|
||||
disableDNS bool
|
||||
disableFirewall bool
|
||||
blockLANAccess bool
|
||||
blockInbound bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -28,4 +32,11 @@ func init() {
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -98,13 +98,18 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -119,7 +124,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
}
|
||||
|
||||
func startClientDaemon(
|
||||
t *testing.T, ctx context.Context, _, configPath string,
|
||||
t *testing.T, ctx context.Context, _, _ string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
@@ -129,7 +134,7 @@ func startClientDaemon(
|
||||
s := grpc.NewServer()
|
||||
|
||||
server := client.New(ctx,
|
||||
configPath, "")
|
||||
"", false)
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
|
||||
for _, stage := range resp.Stages {
|
||||
if stage.ForwardingDetails != nil {
|
||||
|
||||
421
client/cmd/up.go
421
client/cmd/up.go
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@@ -35,6 +37,9 @@ const (
|
||||
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
|
||||
profileNameFlag = "profile"
|
||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -42,6 +47,8 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
profileName string
|
||||
configPath string
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -55,12 +62,11 @@ func init() {
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
@@ -71,6 +77,8 @@ func init() {
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||
|
||||
}
|
||||
|
||||
@@ -80,7 +88,7 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
@@ -102,13 +110,41 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
if foregroundMode {
|
||||
return runInForegroundMode(ctx, cmd)
|
||||
pm := profilemanager.NewProfileManager()
|
||||
|
||||
username, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
return runInDaemonMode(ctx, cmd)
|
||||
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
err = switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
profileSwitched = true
|
||||
}
|
||||
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
|
||||
if foregroundMode {
|
||||
return runInForegroundMode(ctx, cmd, activeProf)
|
||||
}
|
||||
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
|
||||
}
|
||||
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -119,10 +155,241 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
configFilePath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile file path: %v", err)
|
||||
}
|
||||
|
||||
ic, err := setupConfig(customDNSAddressConverted, cmd, configFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup config: %v", err)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := profilemanager.UpdateOrCreateConfig(*ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse custom DNS address: %v", err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
if !profileSwitched {
|
||||
cmd.Println("Already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := client.Down(ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("call service down method: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
username, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
// set the new config
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
return fmt.Errorf("call service set config method: %v", err)
|
||||
}
|
||||
|
||||
if err := doDaemonUp(ctx, cmd, client, pm, activeProf, customDNSAddressConverted, username.Username); err != nil {
|
||||
return fmt.Errorf("daemon up failed: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, customDNSAddressConverted []byte, username string) error {
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get setup key: %v", err)
|
||||
}
|
||||
|
||||
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
if err := handleSSOLogin(ctx, cmd, loginResp, client, pm); err != nil {
|
||||
return fmt.Errorf("sso login failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &username,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, profileName, username string) *proto.SetConfigRequest {
|
||||
var req proto.SetConfigRequest
|
||||
req.ProfileName = profileName
|
||||
req.Username = username
|
||||
|
||||
req.ManagementUrl = managementURL
|
||||
req.AdminURL = adminURL
|
||||
req.NatExternalIPs = natExternalIPs
|
||||
req.CustomDNSAddress = customDNSAddressConverted
|
||||
req.ExtraIFaceBlacklist = extraIFaceBlackList
|
||||
req.DnsLabels = dnsLabelsValidated.ToPunycodeList()
|
||||
req.CleanDNSLabels = dnsLabels != nil && len(dnsLabels) == 0
|
||||
req.CleanNATExternalIPs = natExternalIPs != nil && len(natExternalIPs) == 0
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
req.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
req.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.InterfaceName = &interfaceName
|
||||
}
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int64(wireguardPort)
|
||||
req.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
req.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
req.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
req.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
req.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
req.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
req.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
req.DisableDns = &disableDNS
|
||||
}
|
||||
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
req.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
req.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
req.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
return &req
|
||||
}
|
||||
|
||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFilePath string) (*profilemanager.ConfigInput, error) {
|
||||
ic := profilemanager.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
ConfigPath: configFilePath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
@@ -143,7 +410,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
@@ -194,83 +461,28 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
return connectClient.Run(nil)
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
cmd.Println("Already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -295,7 +507,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
loginRequest.InterfaceName = &interfaceName
|
||||
}
|
||||
@@ -330,45 +542,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
loginRequest.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
return &loginRequest, nil
|
||||
}
|
||||
|
||||
func validateNATExternalIPs(list []string) error {
|
||||
@@ -452,7 +633,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
if !isValidAddrPort(customDNSAddress) {
|
||||
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||
}
|
||||
if customDNSAddress == "" && logFile != "console" {
|
||||
if customDNSAddress == "" && util.FindFirstLogPath(logFiles) != "" {
|
||||
parsed = []byte("empty")
|
||||
} else {
|
||||
parsed = []byte(customDNSAddress)
|
||||
|
||||
@@ -3,18 +3,55 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
var cliAddr string
|
||||
|
||||
func TestUpDaemon(t *testing.T) {
|
||||
mgmAddr := startTestingServices(t)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get current user: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.AddProfile("test1", currUser.Username)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add profile: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test1",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
mgmAddr := startTestingServices(t)
|
||||
|
||||
confPath := tempDir + "/config.json"
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
@@ -26,7 +27,7 @@ var ErrClientNotStarted = errors.New("client not started")
|
||||
// Client manages a netbird embedded client instance
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *internal.Config
|
||||
config *profilemanager.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
@@ -88,9 +89,9 @@ func New(opts Options) (*Client, error) {
|
||||
}
|
||||
|
||||
t := true
|
||||
var config *internal.Config
|
||||
var config *profilemanager.Config
|
||||
var err error
|
||||
input := internal.ConfigInput{
|
||||
input := profilemanager.ConfigInput{
|
||||
ConfigPath: opts.ConfigPath,
|
||||
ManagementURL: opts.ManagementURL,
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
@@ -98,9 +99,9 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = internal.UpdateOrCreateConfig(input)
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
} else {
|
||||
config, err = internal.CreateInMemoryConfig(input)
|
||||
config, err = profilemanager.CreateInMemoryConfig(input)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
|
||||
@@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -148,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -199,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
@@ -220,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -243,6 +252,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -57,18 +57,18 @@ type ruleInfo struct {
|
||||
}
|
||||
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Source firewall.Network
|
||||
Destination firewall.Network
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
SetName string
|
||||
}
|
||||
|
||||
type routeRules map[string][]string
|
||||
|
||||
// the ipset library currently does not support comments, so we use the name only (string)
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
type router struct {
|
||||
@@ -129,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
@@ -140,27 +140,28 @@ func (r *router) AddRouteFiltering(
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
var setName string
|
||||
var source firewall.Network
|
||||
if len(sources) > 1 {
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
source.Set = firewall.NewPrefixSet(sources)
|
||||
} else if len(sources) > 0 {
|
||||
source.Prefix = sources[0]
|
||||
}
|
||||
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: sources,
|
||||
Source: source,
|
||||
Destination: destination,
|
||||
Proto: proto,
|
||||
SPort: sPort,
|
||||
DPort: dPort,
|
||||
Action: action,
|
||||
SetName: setName,
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
rule, err := r.genRouteRuleSpec(params, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate route rule spec: %w", err)
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||
@@ -183,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.ID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
@@ -204,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule []string) string {
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
return rule[i+3]
|
||||
func (r *router) decrementSetCounter(rule []string) error {
|
||||
sets := r.findSets(rule)
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) findSets(rule []string) []string {
|
||||
var sets []string
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
sets = append(sets, rule[i+3])
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
@@ -231,15 +241,13 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
@@ -266,16 +274,14 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -313,8 +319,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -599,12 +607,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", pair.Source.String(),
|
||||
"-d", pair.Destination.String(),
|
||||
)
|
||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -s: %w", err)
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule,
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
@@ -622,6 +644,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
@@ -787,17 +813,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
|
||||
var rule []string
|
||||
|
||||
if params.SetName != "" {
|
||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||
} else if len(params.Sources) > 0 {
|
||||
source := params.Sources[0]
|
||||
rule = append(rule, "-s", source.String())
|
||||
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -s: %w", err)
|
||||
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", params.Destination, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
@@ -807,7 +837,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
|
||||
rule = append(rule, "-j", actionToStr(params.Action))
|
||||
|
||||
return rule
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
direction := "src"
|
||||
if flag == "-d" {
|
||||
direction = "dst"
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
@@ -60,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
@@ -347,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
var source firewall.Network
|
||||
if len(tt.sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||
} else if len(tt.sources) > 0 {
|
||||
source.Prefix = tt.sources[0]
|
||||
}
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: tt.sources,
|
||||
Destination: tt.destination,
|
||||
Source: source,
|
||||
Destination: firewall.Network{Prefix: tt.destination},
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
SetName: "",
|
||||
}
|
||||
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
@@ -378,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSetNameInRule(t *testing.T) {
|
||||
r := &router{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Basic rule with two sets",
|
||||
rule: []string{
|
||||
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
|
||||
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
|
||||
},
|
||||
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
|
||||
},
|
||||
{
|
||||
name: "No sets",
|
||||
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Multiple sets with different positions",
|
||||
rule: []string{
|
||||
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
|
||||
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
|
||||
},
|
||||
expected: []string{"set1", "set-abc123"},
|
||||
},
|
||||
{
|
||||
name: "Boundary case - sequence appears at end",
|
||||
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
|
||||
expected: []string{"final-set"},
|
||||
},
|
||||
{
|
||||
name: "Incomplete pattern - missing set name",
|
||||
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := r.findSets(tc.rule)
|
||||
|
||||
if len(result) != len(tc.expected) {
|
||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, set := range result {
|
||||
if set != tc.expected[i] {
|
||||
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +40,18 @@ const (
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
// String returns the string representation of the action
|
||||
func (a Action) String() string {
|
||||
switch a {
|
||||
case ActionAccept:
|
||||
return "accept"
|
||||
case ActionDrop:
|
||||
return "drop"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
@@ -50,6 +59,33 @@ const (
|
||||
ActionDrop
|
||||
)
|
||||
|
||||
// Network is a rule destination, either a set or a prefix
|
||||
type Network struct {
|
||||
Set Set
|
||||
Prefix netip.Prefix
|
||||
}
|
||||
|
||||
// String returns the string representation of the destination
|
||||
func (d Network) String() string {
|
||||
if d.Prefix.IsValid() {
|
||||
return d.Prefix.String()
|
||||
}
|
||||
if d.IsSet() {
|
||||
return d.Set.HashedName()
|
||||
}
|
||||
return "<invalid network>"
|
||||
}
|
||||
|
||||
// IsSet returns true if the destination is a set
|
||||
func (d Network) IsSet() bool {
|
||||
return d.Set != Set{}
|
||||
}
|
||||
|
||||
// IsPrefix returns true if the destination is a valid prefix
|
||||
func (d Network) IsPrefix() bool {
|
||||
return d.Prefix.IsValid()
|
||||
}
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
@@ -80,13 +116,14 @@ type Manager interface {
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
sPort, dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
@@ -119,6 +156,9 @@ type Manager interface {
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
@@ -153,22 +193,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
|
||||
@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
result1 := manager.NewPrefixSet(prefixes1)
|
||||
result2 := manager.NewPrefixSet(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
result := manager.NewPrefixSet(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result1 := manager.NewPrefixSet([]netip.Prefix{})
|
||||
result2 := manager.NewPrefixSet([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
result1 := manager.NewPrefixSet(prefixes1)
|
||||
result2 := manager.NewPrefixSet(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouterPair struct {
|
||||
ID route.ID
|
||||
Source netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Source Network
|
||||
Destination Network
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
}
|
||||
|
||||
74
client/firewall/manager/set.go
Normal file
74
client/firewall/manager/set.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type Set struct {
|
||||
hash [4]byte
|
||||
comment string
|
||||
}
|
||||
|
||||
// String returns the string representation of the set: hashed name and comment
|
||||
func (h Set) String() string {
|
||||
if h.comment == "" {
|
||||
return h.HashedName()
|
||||
}
|
||||
return h.HashedName() + ": " + h.comment
|
||||
}
|
||||
|
||||
// HashedName returns the string representation of the hash
|
||||
func (h Set) HashedName() string {
|
||||
return fmt.Sprintf(
|
||||
"nb-%s",
|
||||
hex.EncodeToString(h.hash[:]),
|
||||
)
|
||||
}
|
||||
|
||||
// Comment returns the comment of the set
|
||||
func (h Set) Comment() string {
|
||||
return h.comment
|
||||
}
|
||||
|
||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(prefixes)
|
||||
|
||||
hash := sha256.New()
|
||||
for _, src := range prefixes {
|
||||
bytes, err := src.MarshalBinary()
|
||||
if err != nil {
|
||||
log.Warnf("failed to marshal prefix %s: %v", src, err)
|
||||
}
|
||||
hash.Write(bytes)
|
||||
}
|
||||
var set Set
|
||||
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
// NewDomainSet generates a unique name for an ipset based on the given domains.
|
||||
func NewDomainSet(domains domain.List) Set {
|
||||
slices.Sort(domains)
|
||||
|
||||
hash := sha256.New()
|
||||
for _, d := range domains {
|
||||
hash.Write([]byte(d.PunycodeString()))
|
||||
}
|
||||
set := Set{
|
||||
comment: domains.SafeString(),
|
||||
}
|
||||
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||
|
||||
return set
|
||||
}
|
||||
@@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -171,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -242,7 +245,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close closes the firewall manager
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -325,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -359,6 +368,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Data: ip.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -282,14 +273,14 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
@@ -298,8 +289,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
@@ -44,9 +43,14 @@ const (
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
)
|
||||
|
||||
type setInput struct {
|
||||
set firewall.Set
|
||||
prefixes []netip.Prefix
|
||||
}
|
||||
|
||||
type router struct {
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
@@ -54,7 +58,7 @@ type router struct {
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
@@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
@@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
@@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering(
|
||||
chain := r.chains[chainNameRoutingFw]
|
||||
var exprs []expr.Any
|
||||
|
||||
var source firewall.Network
|
||||
switch {
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||
case len(sources) == 1:
|
||||
// If there's only one source, we can use it directly
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
||||
source.Prefix = sources[0]
|
||||
default:
|
||||
// If there are multiple sources, create or get an ipset
|
||||
var err error
|
||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
||||
}
|
||||
// If there are multiple sources, use a set
|
||||
source.Set = firewall.NewPrefixSet(sources)
|
||||
}
|
||||
|
||||
// Handle destination
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
sourceExp, err := r.applyNetwork(source, sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
exprs = append(exprs, sourceExp...)
|
||||
|
||||
destExp, err := r.applyNetwork(destination, nil, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
@@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering(
|
||||
rule = r.conn.AddRule(rule)
|
||||
}
|
||||
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||
set: set,
|
||||
prefixes: prefixes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
)
|
||||
return exprs, nil
|
||||
return getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
@@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
return fmt.Errorf("delete: %w", err)
|
||||
}
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
Table: r.workTable,
|
||||
nfset := &nftables.Set{
|
||||
Name: setName,
|
||||
Comment: input.set.Comment(),
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range sources {
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
|
||||
if err := r.conn.AddSet(set, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return set, nil
|
||||
return elements
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
@@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
r.conn.DelSet(set)
|
||||
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||
r.conn.DelSet(nfset)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
@@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
return lookup.SetName
|
||||
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
|
||||
sets := r.findSets(rule)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) findSets(rule *nftables.Rule) []string {
|
||||
var sets []string
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
sets = append(sets, lookup.SetName)
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
@@ -560,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -586,7 +595,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -594,19 +604,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs := getCtNewExprs()
|
||||
exprs = append(exprs,
|
||||
// interface matching
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
@@ -616,7 +629,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
)
|
||||
}
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs = append(exprs, getCtNewExprs()...)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
@@ -646,7 +662,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
@@ -729,8 +747,15 @@ func (r *router) addPostroutingRules() error {
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
@@ -739,7 +764,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
},
|
||||
}
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
@@ -752,7 +778,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: expression,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
@@ -767,11 +793,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -974,20 +1002,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -995,10 +1021,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
// TODO: rollback set counter
|
||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1006,16 +1032,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1027,7 +1056,7 @@ func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
return fmt.Errorf(" unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
@@ -1301,13 +1330,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||
var offset uint32
|
||||
if source {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
setPrefixes []netip.Prefix,
|
||||
isSource bool,
|
||||
) ([]expr.Any, error) {
|
||||
if network.IsSet() {
|
||||
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source: %w", err)
|
||||
}
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return applyPrefix(network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
@@ -1415,3 +1485,27 @@ func getCtNewExprs() []expr.Any {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
printNftSets()
|
||||
|
||||
@@ -15,8 +15,8 @@ var (
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
@@ -24,8 +24,8 @@ var (
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +40,8 @@ var (
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -22,7 +21,7 @@ const (
|
||||
firewallRuleName = "Netbird"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||
func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
|
||||
@@ -62,5 +62,5 @@ type ConnKey struct {
|
||||
}
|
||||
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package conntrack
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -19,6 +20,10 @@ const (
|
||||
DefaultICMPTimeout = 30 * time.Second
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||
MaxICMPPayloadLength = 28
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
@@ -29,7 +34,7 @@ type ICMPConnKey struct {
|
||||
}
|
||||
|
||||
func (i ICMPConnKey) String() string {
|
||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
@@ -50,6 +55,72 @@ type ICMPTracker struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
||||
type ICMPInfo struct {
|
||||
TypeCode layers.ICMPv4TypeCode
|
||||
PayloadData [MaxICMPPayloadLength]byte
|
||||
// actual length of valid data
|
||||
PayloadLen int
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||
func (info ICMPInfo) String() string {
|
||||
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return info.TypeCode.String()
|
||||
}
|
||||
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||
func (info ICMPInfo) isErrorMessage() bool {
|
||||
typ := info.TypeCode.Type()
|
||||
return typ == 3 || // Destination Unreachable
|
||||
typ == 5 || // Redirect
|
||||
typ == 11 || // Time Exceeded
|
||||
typ == 12 // Parameter Problem
|
||||
}
|
||||
|
||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||
func (info ICMPInfo) parseOriginalPacket() string {
|
||||
if info.PayloadLen < MaxICMPPayloadLength {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: handle IPv6
|
||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
protocol := info.PayloadData[9]
|
||||
srcIP := net.IP(info.PayloadData[12:16])
|
||||
dstIP := net.IP(info.PayloadData[16:20])
|
||||
|
||||
transportData := info.PayloadData[20:]
|
||||
|
||||
switch nftypes.Protocol(protocol) {
|
||||
case nftypes.TCP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.UDP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.ICMP:
|
||||
icmpType := transportData[0]
|
||||
icmpCode := transportData[1]
|
||||
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
||||
}
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
@@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||
func (t *ICMPTracker) TrackOutbound(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||
func (t *ICMPTracker) TrackInbound(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
ruleId []byte,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||
func (t *ICMPTracker) track(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
direction nftypes.Direction,
|
||||
ruleId []byte,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
typ, code := typecode.Type(), typecode.Code()
|
||||
icmpInfo := ICMPInfo{
|
||||
TypeCode: typecode,
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
icmpInfo.PayloadLen = len(payload)
|
||||
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
||||
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
||||
}
|
||||
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
||||
}
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
@@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
@@ -39,8 +39,12 @@ const (
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
||||
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
||||
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
||||
|
||||
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
||||
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||
)
|
||||
|
||||
@@ -49,10 +53,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
type RouteRules []RouteRule
|
||||
type RouteRules []*RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
return -1
|
||||
@@ -71,7 +75,6 @@ type Manager struct {
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@@ -99,6 +102,14 @@ type Manager struct {
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
|
||||
// Internal 1:1 DNAT
|
||||
dnatEnabled atomic.Bool
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -146,6 +157,11 @@ func parseCreateEnv() (bool, bool) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||
}
|
||||
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
||||
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
@@ -179,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -201,41 +218,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.blockInvalidRouted(iface); err != nil {
|
||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse wireguard network: %w", err)
|
||||
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
rule, err := m.addRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return fmt.Errorf("block wg nte : %w", err)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
|
||||
return nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting() error {
|
||||
@@ -273,7 +284,7 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
||||
case !m.netstack && m.nativeFirewall != nil:
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
@@ -330,6 +341,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return m.stateful
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
@@ -413,10 +428,23 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func (m *Manager) addRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
@@ -426,34 +454,39 @@ func (m *Manager) AddRouteFiltering(
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
if destination.IsPrefix() {
|
||||
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.ID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
})
|
||||
if idx < 0 {
|
||||
@@ -493,30 +526,60 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.UpdateSet(set, prefixes)
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var matches []*RouteRule
|
||||
for _, rule := range m.routeRules {
|
||||
if rule.dstSet == set {
|
||||
matches = append(matches, rule)
|
||||
}
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return fmt.Errorf("no route rule found for set: %s", set)
|
||||
}
|
||||
|
||||
destinations := matches[0].destinations
|
||||
for _, prefix := range prefixes {
|
||||
if prefix.Addr().Is4() {
|
||||
destinations = append(destinations, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||
cmp := a.Addr().Compare(b.Addr())
|
||||
if cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
return a.Bits() - b.Bits()
|
||||
})
|
||||
|
||||
destinations = slices.Compact(destinations)
|
||||
|
||||
for _, rule := range matches {
|
||||
rule.destinations = destinations
|
||||
}
|
||||
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
// FilterOutBound filters outgoing packets
|
||||
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
||||
return m.filterOutbound(packetData, size)
|
||||
}
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||
return m.dropFilter(packetData, size)
|
||||
// FilterInbound filters incoming packets
|
||||
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
||||
return m.filterInbound(packetData, size)
|
||||
}
|
||||
|
||||
// UpdateLocalIPs updates the list of local IPs
|
||||
@@ -524,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||
}
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -546,9 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if m.stateful {
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
}
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -600,7 +662,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -613,7 +675,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
||||
return false
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// filterInbound implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@@ -676,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
return false
|
||||
}
|
||||
@@ -717,9 +786,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack && m.localForwarding {
|
||||
return m.handleNetstackLocalTraffic(packetData)
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
// track inbound packets to get the correct direction and session id for flows
|
||||
@@ -729,8 +799,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
fwd := m.forwarder.Load()
|
||||
if fwd == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
@@ -764,7 +833,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
if !pass {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
@@ -790,8 +860,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
if fwd == nil {
|
||||
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
||||
} else {
|
||||
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
|
||||
|
||||
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject routed packet: %v", err)
|
||||
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -988,8 +1061,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
destMatched := false
|
||||
for _, dst := range rule.destinations {
|
||||
if dst.Contains(dstAddr) {
|
||||
destMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !destMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1017,11 +1097,6 @@ func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
@@ -1091,7 +1166,22 @@ func (m *Manager) EnableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.determineRouting()
|
||||
if err := m.determineRouting(); err != nil {
|
||||
return fmt.Errorf("determine routing: %w", err)
|
||||
}
|
||||
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rule, err := m.blockInvalidRouted(m.wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("block invalid routed: %w", err)
|
||||
}
|
||||
|
||||
m.blockRule = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
@@ -1116,5 +1206,12 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
if m.blockRule != nil {
|
||||
if err := m.deleteRouteRule(m.blockRule); err != nil {
|
||||
return fmt.Errorf("delete block rule: %w", err)
|
||||
}
|
||||
m.blockRule = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Apply scenario-specific setup
|
||||
sc.setupFunc(manager)
|
||||
|
||||
@@ -193,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.filterOutbound(outbound, 0)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.filterInbound(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -219,18 +214,13 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Pre-populate connection table
|
||||
srcIPs := generateRandomIPs(count)
|
||||
dstIPs := generateRandomIPs(count)
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.filterOutbound(outbound, 0)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -238,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
manager.filterOutbound(testOut, 0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, 0)
|
||||
manager.filterInbound(testIn, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -267,23 +257,18 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.filterOutbound(outbound, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.filterInbound(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "post_handshake",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -477,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.filterOutbound(outbound, 0)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.filterOutbound(syn, 0)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.filterInbound(synack, 0)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.filterOutbound(ack, 0)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.filterInbound(inbound, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@@ -624,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.filterOutbound(syn, 0)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.filterInbound(synack, 0)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.filterOutbound(ack, 0)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -655,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.filterOutbound(outPackets[connIdx], 0)
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
manager.filterInbound(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@@ -761,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
manager.filterOutbound(p.syn, 0)
|
||||
manager.filterInbound(p.synAck, 0)
|
||||
manager.filterOutbound(p.ack, 0)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
manager.filterOutbound(p.request, 0)
|
||||
manager.filterInbound(p.response, 0)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
manager.filterOutbound(p.finClient, 0)
|
||||
manager.filterInbound(p.ackServer, 0)
|
||||
manager.filterInbound(p.finServer, 0)
|
||||
manager.filterOutbound(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
@@ -826,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.filterOutbound(syn, 0)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.filterInbound(synack, 0)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.filterOutbound(ack, 0)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -856,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
manager.filterOutbound(outPackets[connIdx], 0)
|
||||
manager.filterInbound(inPackets[connIdx], 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
@@ -950,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
manager.filterOutbound(p.syn, 0)
|
||||
manager.filterInbound(p.synAck, 0)
|
||||
manager.filterOutbound(p.ack, 0)
|
||||
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
manager.filterOutbound(p.request, 0)
|
||||
manager.filterInbound(p.response, 0)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
manager.filterOutbound(p.finClient, 0)
|
||||
manager.filterInbound(p.ackServer, 0)
|
||||
manager.filterInbound(p.finServer, 0)
|
||||
manager.filterOutbound(p.ackClient, 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -15,15 +15,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
localIP := net.ParseIP("100.10.0.100")
|
||||
wgNet := &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
@@ -42,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -188,16 +183,313 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow UDP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "TCP packet doesn't match UDP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "UDP packet doesn't match TCP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match TCP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match UDP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic within port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Block TCP traffic outside port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Edge Case - Port at Range Boundary",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8100,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "UDP Port Range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 5060,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple source ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
// New drop test cases
|
||||
{
|
||||
name: "Drop TCP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop UDP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop ICMP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop all traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolALL,
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop traffic from multiple source ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop TCP traffic within port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Accept TCP traffic outside drop port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Drop TCP traffic with source port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 32100,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Mixed rule - drop specific port but allow other ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
@@ -217,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
isDropped := manager.FilterInbound(packet, 0)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -283,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
localIP, wgNet, err := net.ParseCIDR(network)
|
||||
require.NoError(tb, err)
|
||||
wgNet := netip.MustParsePrefix(network)
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
@@ -303,8 +594,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
@@ -321,7 +612,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
type rule struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -347,7 +638,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -363,7 +654,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -379,7 +670,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -395,7 +686,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 53,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{Values: []uint16{53}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -409,7 +700,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
proto: fw.ProtocolICMP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
@@ -424,7 +715,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -440,7 +731,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -456,7 +747,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -472,7 +763,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -488,7 +779,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -507,7 +798,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -521,7 +812,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
@@ -536,33 +827,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple source networks with mismatched protocol",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
// Should not match TCP rule
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
@@ -572,7 +843,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -588,7 +859,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -604,7 +875,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
@@ -621,7 +892,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -640,7 +911,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 7999,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -659,7 +930,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -678,7 +949,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -700,7 +971,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8100,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -719,7 +990,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 5060,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -738,7 +1009,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -757,7 +1028,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -773,7 +1044,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
@@ -791,17 +1062,158 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "Drop empty destination set",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: fw.Network{Set: fw.Set{}},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "Accept TCP traffic outside drop port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Allow UDP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "TCP packet doesn't match UDP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "UDP packet doesn't match TCP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match TCP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match UDP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.rule.action == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
@@ -821,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
@@ -836,7 +1248,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name string
|
||||
rules []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -857,7 +1269,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name: "Drop rules take precedence over accept",
|
||||
rules: []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -866,7 +1278,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Accept rule added first
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -874,7 +1286,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Drop rule added second but should be evaluated first
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -912,7 +1324,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name: "Multiple drop rules take precedence",
|
||||
rules: []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -921,14 +1333,14 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Accept all
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
{
|
||||
// Drop specific port
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -936,7 +1348,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Drop different port
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -1015,3 +1427,50 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteACLSet(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
// Add rule that uses the set (initially empty)
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
@@ -270,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -284,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
@@ -327,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
if m.filterInbound(buf.Bytes(), 0) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
@@ -395,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer func() {
|
||||
@@ -457,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
result := manager.filterOutbound(buf.Bytes(), 0)
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -467,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
result = manager.filterOutbound(buf.Bytes(), 0)
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -508,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.decoders = sync.Pool{
|
||||
@@ -568,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
@@ -635,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -684,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -706,8 +691,208 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSetMerge(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
initialPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
// Update the set with initial prefixes
|
||||
err = manager.UpdateSet(set, initialPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test initial prefixes work
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP1 := netip.MustParseAddr("10.0.0.100")
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
|
||||
|
||||
newPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, newPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
// Check that new prefixes are included
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
|
||||
// Verify the rule has all prefixes
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
foundRule = true
|
||||
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||
"Rule should have all prefixes merged")
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
require.True(t, foundRule, "Rule should be found")
|
||||
}
|
||||
|
||||
func TestUpdateSetDeduplication(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
initialPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, initialPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the internal state for deduplication
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
foundRule = true
|
||||
// Should have deduplicated to 2 prefixes
|
||||
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||
|
||||
// Check the prefixes are correct
|
||||
expectedPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
for i, prefix := range expectedPrefixes {
|
||||
require.True(t, r.destinations[i] == prefix,
|
||||
"Prefix should match expected value")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
require.True(t, foundRule, "Rule should be found")
|
||||
|
||||
// Test with overlapping prefixes of different sizes
|
||||
overlappingPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/16"), // More general
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
|
||||
netip.MustParsePrefix("192.168.0.0/16"), // More general
|
||||
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, overlappingPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||
manager.mutex.RLock()
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||
require.Len(t, r.destinations, 4,
|
||||
"Overlapping prefixes should not be deduplicated")
|
||||
|
||||
// Verify they're sorted correctly (more specific prefixes should come first)
|
||||
prefixes := make([]string, 0, len(r.destinations))
|
||||
for _, p := range r.destinations {
|
||||
prefixes = append(prefixes, p.String())
|
||||
}
|
||||
|
||||
// Check sorted order
|
||||
require.Equal(t, []string{
|
||||
"10.0.0.0/16",
|
||||
"10.0.0.0/24",
|
||||
"192.168.0.0/16",
|
||||
"192.168.1.0/24",
|
||||
}, prefixes, "Prefixes should be sorted")
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
// Test functionality with all prefixes
|
||||
testCases := []struct {
|
||||
dstIP netip.Addr
|
||||
expected bool
|
||||
desc string
|
||||
}{
|
||||
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
|
||||
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
|
||||
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
|
||||
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
|
||||
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
|
||||
}
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -17,6 +19,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
@@ -29,14 +32,16 @@ const (
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
@@ -66,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -111,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -162,8 +166,39 @@ func (f *Forwarder) Stop() {
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
||||
|
||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||
key := buildKey(srcIP, dstIP, srcPort, dstPort)
|
||||
f.ruleIdMap.LoadOrStore(key, ruleID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return value.([]byte), true
|
||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||
return value.([]byte), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return
|
||||
}
|
||||
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
|
||||
}
|
||||
|
||||
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
|
||||
return conntrack.ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return
|
||||
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return
|
||||
return 0
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
|
||||
var rxPackets, txPackets uint64
|
||||
if rxBytes > 0 {
|
||||
rxPackets = 1
|
||||
}
|
||||
if txBytes > 0 {
|
||||
txPackets = 1
|
||||
}
|
||||
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
// TODO: get packets/bytes
|
||||
})
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||
return
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
|
||||
@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
@@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var txBytes, rxBytes int64
|
||||
var outboundErr, inboundErr error
|
||||
|
||||
// outbound->inbound: copy from pConn.conn to pConn.outConn
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
defer wg.Done()
|
||||
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
// inbound->outbound: copy from pConn.outConn to pConn.conn
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
defer wg.Done()
|
||||
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||
return
|
||||
wg.Wait()
|
||||
|
||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
||||
}
|
||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
rxPackets = udpStats.PacketsSent.Value()
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
// copy reads from src and writes to dst.
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
var totalBytes int64 = 0
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
return totalBytes, ctx.Err()
|
||||
}
|
||||
|
||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
return totalBytes, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
nWritten, err := dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
totalBytes += int64(nWritten)
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
}
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
|
||||
@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
|
||||
408
client/firewall/uspfilter/nat.go
Normal file
408
client/firewall/uspfilter/nat.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var sum1, sum2 uint32
|
||||
|
||||
// Parallel processing - unroll and compute two sums simultaneously
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||
// Skip checksum field at [10:12]
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
||||
|
||||
sum := sum1 + sum2
|
||||
|
||||
// Handle remaining bytes for headers > 20 bytes
|
||||
for i := 20; i < len(header)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
|
||||
if len(header)%2 == 1 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
|
||||
// Optimized carry fold - single iteration handles most cases
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
|
||||
// Process 16 bytes at once with 4 parallel accumulators
|
||||
for i <= len(data)-16 {
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
||||
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
||||
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
||||
i += 16
|
||||
}
|
||||
|
||||
sum := sum1 + sum2 + sum3 + sum4
|
||||
|
||||
// Handle remaining bytes
|
||||
for i < len(data)-1 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
i += 2
|
||||
}
|
||||
|
||||
if len(data)%2 == 1 {
|
||||
sum += uint32(data[len(data)-1]) << 8
|
||||
}
|
||||
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
reverse: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
delete(b.reverse, translated)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||
}
|
||||
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
}
|
||||
|
||||
m.dnatMappings[originalAddr] = translatedAddr
|
||||
m.dnatBiMap.set(originalAddr, translatedAddr)
|
||||
|
||||
if len(m.dnatMappings) == 1 {
|
||||
m.dnatEnabled.Store(true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||
}
|
||||
|
||||
delete(m.dnatMappings, originalAddr)
|
||||
m.dnatBiMap.delete(originalAddr)
|
||||
if len(m.dnatMappings) == 0 {
|
||||
m.dnatEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
translated, exists := m.dnatBiMap.getTranslated(addr)
|
||||
m.dnatMutex.RUnlock()
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
||||
m.dnatMutex.RUnlock()
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
if oldChecksum == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
icmpData := packetData[icmpStart:]
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||
checksum := icmpChecksum(icmpData)
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||
}
|
||||
if len(oldBytes)%2 == 1 {
|
||||
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||
}
|
||||
|
||||
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||
}
|
||||
if len(newBytes)%2 == 1 {
|
||||
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||
}
|
||||
}
|
||||
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// BenchmarkDNATTranslation measures the performance of DNAT operations
|
||||
func BenchmarkDNATTranslation(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_with_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
description: "TCP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "tcp_without_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
description: "TCP packet without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_with_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
description: "UDP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "udp_without_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
description: "UDP packet without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "icmp_with_dnat",
|
||||
proto: layers.IPProtocolICMPv4,
|
||||
setupDNAT: true,
|
||||
description: "ICMP packet with DNAT translation enabled",
|
||||
},
|
||||
{
|
||||
name: "icmp_without_dnat",
|
||||
proto: layers.IPProtocolICMPv4,
|
||||
setupDNAT: false,
|
||||
description: "ICMP packet without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup DNAT mapping if needed
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Create test packets
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||
|
||||
// Pre-establish connection for reverse DNAT test
|
||||
if sc.setupDNAT {
|
||||
manager.filterOutbound(outboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
// Benchmark outbound DNAT translation
|
||||
b.Run("outbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time since translation modifies it
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark inbound reverse DNAT translation
|
||||
if sc.setupDNAT {
|
||||
b.Run("inbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time since translation modifies it
|
||||
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup multiple DNAT mappings
|
||||
numMappings := 100
|
||||
originalIPs := make([]netip.Addr, numMappings)
|
||||
translatedIPs := make([]netip.Addr, numMappings)
|
||||
|
||||
for i := 0; i < numMappings; i++ {
|
||||
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
// Pre-generate packets
|
||||
outboundPackets := make([][]byte, numMappings)
|
||||
inboundPackets := make([][]byte, numMappings)
|
||||
for i := 0; i < numMappings; i++ {
|
||||
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
|
||||
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||
// Establish connections
|
||||
manager.filterOutbound(outboundPackets[i], 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
b.Run("concurrent_outbound", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
idx := i % numMappings
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("concurrent_inbound", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
idx := i % numMappings
|
||||
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||
manager.filterInbound(packet, 0)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
|
||||
func BenchmarkDNATScaling(b *testing.B) {
|
||||
mappingCounts := []int{1, 10, 100, 1000}
|
||||
|
||||
for _, count := range mappingCounts {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
// Setup DNAT mappings
|
||||
for i := 0; i < count; i++ {
|
||||
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Test with the last mapping added (worst case for lookup)
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateDNATTestPacket creates a test packet for DNAT benchmarking
|
||||
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP.AsSlice(),
|
||||
DstIP: dstIP.AsSlice(),
|
||||
Protocol: proto,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = udp
|
||||
case layers.IPProtocolICMPv4:
|
||||
icmp := &layers.ICMPv4{
|
||||
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||
}
|
||||
transportLayer = icmp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||
require.NoError(tb, err)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
|
||||
func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
// Create test data for checksum calculations
|
||||
testData := make([]byte, 64) // Typical packet size for checksum testing
|
||||
for i := range testData {
|
||||
testData[i] = byte(i)
|
||||
}
|
||||
|
||||
b.Run("ipv4_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("icmp_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = icmpChecksum(testData)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("incremental_update", func(b *testing.B) {
|
||||
oldBytes := []byte{192, 168, 1, 100}
|
||||
newBytes := []byte{10, 0, 0, 100}
|
||||
oldChecksum := uint16(0x1234)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(b, err)
|
||||
|
||||
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time to isolate allocation testing
|
||||
testPacket := make([]byte, len(packet))
|
||||
copy(testPacket, packet)
|
||||
|
||||
// Parse the packet fresh each time to get a clean decoder
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
manager.translateOutboundDNAT(testPacket, d)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
|
||||
func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||
// Create a test packet
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
|
||||
|
||||
b.Run("direct_byte_access", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Direct extraction from packet bytes
|
||||
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("decoder_extraction", func(b *testing.B) {
|
||||
// Create decoder once for comparison
|
||||
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Extract using decoder (traditional method)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
_ = dst
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
|
||||
func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
// Create test IPv4 header (20 bytes)
|
||||
header := make([]byte, 20)
|
||||
for i := range header {
|
||||
header[i] = byte(i)
|
||||
}
|
||||
// Clear checksum field
|
||||
header[10] = 0
|
||||
header[11] = 0
|
||||
|
||||
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ipv4Checksum(header)
|
||||
}
|
||||
})
|
||||
|
||||
// Test incremental checksum updates
|
||||
oldIP := []byte{192, 168, 1, 100}
|
||||
newIP := []byte{10, 0, 0, 100}
|
||||
oldChecksum := uint16(0x1234)
|
||||
|
||||
b.Run("optimized_incremental_update", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
145
client/firewall/uspfilter/nat_test.go
Normal file
145
client/firewall/uspfilter/nat_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||
|
||||
// Add DNAT mapping
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"TCP", layers.IPProtocolTCP, 12345, 80},
|
||||
{"UDP", layers.IPProtocolUDP, 12345, 53},
|
||||
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test outbound DNAT translation
|
||||
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
originalOutbound := make([]byte, len(outboundPacket))
|
||||
copy(originalOutbound, outboundPacket)
|
||||
|
||||
// Process outbound packet (should translate destination)
|
||||
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, translated, "Outbound packet should be translated")
|
||||
|
||||
// Verify destination IP was changed
|
||||
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
|
||||
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
|
||||
|
||||
// Test inbound reverse DNAT translation
|
||||
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
|
||||
originalInbound := make([]byte, len(inboundPacket))
|
||||
copy(originalInbound, inboundPacket)
|
||||
|
||||
// Process inbound packet (should reverse translate source)
|
||||
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, reversed, "Inbound packet should be reverse translated")
|
||||
|
||||
// Verify source IP was changed back to original
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
|
||||
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
|
||||
|
||||
// Test that checksums are recalculated correctly
|
||||
if tc.protocol != layers.IPProtocolICMPv4 {
|
||||
// For TCP/UDP, verify the transport checksum was updated
|
||||
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
|
||||
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// parsePacket helper to create a decoder for testing
|
||||
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
t.Helper()
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
// Test adding mapping
|
||||
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify mapping exists
|
||||
result, exists := manager.getDNATTranslation(originalIP)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, translatedIP, result)
|
||||
|
||||
// Test reverse lookup
|
||||
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
|
||||
require.True(t, exists)
|
||||
require.Equal(t, originalIP, reverseResult)
|
||||
|
||||
// Test removing mapping
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify mapping no longer exists
|
||||
_, exists = manager.getDNATTranslation(originalIP)
|
||||
require.False(t, exists)
|
||||
|
||||
_, exists = manager.findReverseDNATMapping(translatedIP)
|
||||
require.False(t, exists)
|
||||
|
||||
// Test error cases
|
||||
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
|
||||
require.Error(t, err, "Should reject invalid original IP")
|
||||
|
||||
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
|
||||
require.Error(t, err, "Should reject invalid translated IP")
|
||||
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
@@ -29,14 +29,15 @@ func (r *PeerRule) ID() string {
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -198,12 +195,12 @@ func TestTracePacket(t *testing.T) {
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -222,12 +219,12 @@ func TestTracePacket(t *testing.T) {
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -245,7 +242,7 @@ func TestTracePacket(t *testing.T) {
|
||||
m.nativeRouter.Store(true)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -263,7 +260,7 @@ func TestTracePacket(t *testing.T) {
|
||||
m.routingEnabled.Store(false)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -425,8 +422,8 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||
"100.10.0.100 should be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||
"172.17.0.2 should not be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
|
||||
"192.168.17.2 should not be recognized as a local IP")
|
||||
|
||||
pb := tc.packetBuilder()
|
||||
|
||||
|
||||
96
client/iface/bind/activity.go
Normal file
96
client/iface/bind/activity.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
const (
|
||||
saveFrequency = int64(5 * time.Second)
|
||||
)
|
||||
|
||||
type PeerRecord struct {
|
||||
Address netip.AddrPort
|
||||
LastActivity atomic.Int64 // UnixNano timestamp
|
||||
}
|
||||
|
||||
type ActivityRecorder struct {
|
||||
mu sync.RWMutex
|
||||
peers map[string]*PeerRecord // publicKey to PeerRecord map
|
||||
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
|
||||
}
|
||||
|
||||
func NewActivityRecorder() *ActivityRecorder {
|
||||
return &ActivityRecorder{
|
||||
peers: make(map[string]*PeerRecord),
|
||||
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastActivities returns a snapshot of peer last activity
|
||||
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
activities := make(map[string]monotime.Time, len(r.peers))
|
||||
for key, record := range r.peers {
|
||||
monoTime := record.LastActivity.Load()
|
||||
activities[key] = monotime.Time(monoTime)
|
||||
}
|
||||
return activities
|
||||
}
|
||||
|
||||
// UpsertAddress adds or updates the address for a publicKey
|
||||
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var record *PeerRecord
|
||||
record, exists := r.peers[publicKey]
|
||||
if exists {
|
||||
delete(r.addrToPeer, record.Address)
|
||||
record.Address = address
|
||||
} else {
|
||||
record = &PeerRecord{
|
||||
Address: address,
|
||||
}
|
||||
record.LastActivity.Store(int64(monotime.Now()))
|
||||
r.peers[publicKey] = record
|
||||
}
|
||||
|
||||
r.addrToPeer[address] = record
|
||||
}
|
||||
|
||||
func (r *ActivityRecorder) Remove(publicKey string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if record, exists := r.peers[publicKey]; exists {
|
||||
delete(r.addrToPeer, record.Address)
|
||||
delete(r.peers, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
// record updates LastActivity for the given address using atomic store
|
||||
func (r *ActivityRecorder) record(address netip.AddrPort) {
|
||||
r.mu.RLock()
|
||||
record, ok := r.addrToPeer[address]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
log.Warnf("could not find record for address %s", address)
|
||||
return
|
||||
}
|
||||
|
||||
now := int64(monotime.Now())
|
||||
last := record.LastActivity.Load()
|
||||
if now-last < saveFrequency {
|
||||
return
|
||||
}
|
||||
|
||||
_ = record.LastActivity.CompareAndSwap(last, now)
|
||||
}
|
||||
25
client/iface/bind/activity_test.go
Normal file
25
client/iface/bind/activity_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
||||
peer := "peer1"
|
||||
ar := NewActivityRecorder()
|
||||
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
|
||||
activities := ar.GetLastActivities()
|
||||
|
||||
p, ok := activities[peer]
|
||||
if !ok {
|
||||
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
||||
}
|
||||
|
||||
if monotime.Since(p) > 5*time.Second {
|
||||
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
||||
}
|
||||
}
|
||||
15
client/iface/bind/control.go
Normal file
15
client/iface/bind/control.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||
func init() {
|
||||
listener := nbnet.NewListener()
|
||||
if listener.ListenConfig.Control != nil {
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -51,22 +53,24 @@ type ICEBind struct {
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
activityRecorder *ActivityRecorder
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -100,6 +104,10 @@ func (s *ICEBind) Close() error {
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||
return s.activityRecorder
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
@@ -146,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
UDPConn: nbnet.WrapPacketConn(conn),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
@@ -199,6 +207,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
continue
|
||||
}
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
|
||||
if isTransportPkg(msg.Buffers, msg.N) {
|
||||
s.activityRecorder.record(addrPort)
|
||||
}
|
||||
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
@@ -257,6 +270,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
copy(buffs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||
|
||||
if isTransportPkg(buffs, sizes[0]) {
|
||||
if ep, ok := eps[0].(*Endpoint); ok {
|
||||
c.activityRecorder.record(ep.AddrPort)
|
||||
}
|
||||
}
|
||||
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
@@ -272,3 +292,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
func isTransportPkg(buffers [][]byte, n int) bool {
|
||||
// The first buffer should contain at least 4 bytes for type
|
||||
if len(buffers[0]) < 4 {
|
||||
return true
|
||||
}
|
||||
|
||||
// WireGuard packet type is a little-endian uint32 at start
|
||||
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
|
||||
|
||||
// Check if packetType matches known WireGuard message types
|
||||
if packetType == 4 && n > 32 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
return
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
var allAddresses []string
|
||||
for _, c := range removedConns {
|
||||
addresses := c.getAddresses()
|
||||
for _, addr := range addresses {
|
||||
delete(m.addressMap, addr)
|
||||
}
|
||||
allAddresses = append(allAddresses, addresses...)
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
for _, addr := range allAddresses {
|
||||
delete(m.addressMap, addr)
|
||||
}
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
for _, addr := range allAddresses {
|
||||
m.notifyAddressRemoval(addr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
existing, ok := m.addressMap[addr]
|
||||
if !ok {
|
||||
existing = []*udpMuxedConn{}
|
||||
}
|
||||
existing = append(existing, conn)
|
||||
m.addressMap[addr] = existing
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
m.addressMapMu.Lock()
|
||||
m.addressMapMu.RLock()
|
||||
var destinationConnList []*udpMuxedConn
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
destinationConnList = append(destinationConnList, storedConns...)
|
||||
}
|
||||
m.addressMapMu.Unlock()
|
||||
m.addressMapMu.RUnlock()
|
||||
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
|
||||
22
client/iface/bind/udp_mux_generic.go
Normal file
22
client/iface/bind/udp_mux_generic.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build !ios
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
|
||||
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
|
||||
conn.RemoveAddress(addr)
|
||||
return
|
||||
}
|
||||
|
||||
// Userspace mode: UDPConn wrapper around nbnet.PacketConn
|
||||
if wrapped, ok := m.params.UDPConn.(*UDPConn); ok {
|
||||
if conn, ok := wrapped.GetPacketConn().(*nbnet.PacketConn); ok {
|
||||
conn.RemoveAddress(addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
7
client/iface/bind/udp_mux_ios.go
Normal file
7
client/iface/bind/udp_mux_ios.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build ios
|
||||
|
||||
package bind
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
|
||||
// wrap UDP connection, process server reflexive messages
|
||||
// before they are passed to the UDPMux connection handler (connWorker)
|
||||
m.params.UDPConn = &udpConn{
|
||||
m.params.UDPConn = &UDPConn{
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
udpMuxParams := UDPMuxParams{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type udpConn struct {
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type UDPConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
@@ -125,7 +124,12 @@ type udpConn struct {
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
return u.PacketConn
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if u.filterFn == nil {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
@@ -164,7 +168,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
17
client/iface/configurer/common.go
Normal file
17
client/iface/configurer/common.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||
ipNets := make([]net.IPNet, len(prefixes))
|
||||
for i, prefix := range prefixes {
|
||||
ipNets[i] = net.IPNet{
|
||||
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
|
||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
||||
}
|
||||
}
|
||||
return ipNets
|
||||
}
|
||||
@@ -5,13 +5,18 @@ package configurer
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
var zeroKey wgtypes.Key
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
}
|
||||
@@ -43,7 +48,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -52,7 +57,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: allowedIps,
|
||||
AllowedIPs: prefixesToIPNets(allowedIps),
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
Endpoint: endpoint,
|
||||
PresharedKey: preSharedKey,
|
||||
@@ -89,10 +94,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return err
|
||||
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
ipNet := net.IPNet{
|
||||
IP: allowedIP.Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||
}
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
@@ -103,7 +108,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
|
||||
PublicKey: peerKeyParsed,
|
||||
UpdateOnly: true,
|
||||
ReplaceAllowedIPs: false,
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
AllowedIPs: []net.IPNet{ipNet},
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
@@ -116,10 +121,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse allowed IP: %w", err)
|
||||
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
ipNet := net.IPNet{
|
||||
IP: allowedIP.Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||
}
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
@@ -187,7 +192,11 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer wg.Close()
|
||||
defer func() {
|
||||
if err := wg.Close(); err != nil {
|
||||
log.Errorf("Failed to close wgctrl client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// validate if device with name exists
|
||||
_, err = wg.Device(c.deviceName)
|
||||
@@ -201,14 +210,75 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
func (c *KernelConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
||||
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
||||
return nil, fmt.Errorf("wgctl: %w", err)
|
||||
}
|
||||
return WGStats{
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
TxBytes: peer.TransmitBytes,
|
||||
RxBytes: peer.ReceiveBytes,
|
||||
}, nil
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("Got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||
}
|
||||
fullStats := &Stats{
|
||||
DeviceName: wgDevice.Name,
|
||||
PublicKey: wgDevice.PublicKey.String(),
|
||||
ListenPort: wgDevice.ListenPort,
|
||||
FWMark: wgDevice.FirewallMark,
|
||||
Peers: []Peer{},
|
||||
}
|
||||
|
||||
for _, p := range wgDevice.Peers {
|
||||
peer := Peer{
|
||||
PublicKey: p.PublicKey.String(),
|
||||
AllowedIPs: p.AllowedIPs,
|
||||
TxBytes: p.TransmitBytes,
|
||||
RxBytes: p.ReceiveBytes,
|
||||
LastHandshake: p.LastHandshakeTime,
|
||||
PresharedKey: p.PresharedKey != zeroKey,
|
||||
}
|
||||
if p.Endpoint != nil {
|
||||
peer.Endpoint = *p.Endpoint
|
||||
}
|
||||
fullStats.Peers = append(fullStats.Peers, peer)
|
||||
}
|
||||
return fullStats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wgctl: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
err = wg.Close()
|
||||
if err != nil {
|
||||
log.Errorf("Got error while closing wgctl: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wgDevice, err := wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||
}
|
||||
|
||||
for _, peer := range wgDevice.Peers {
|
||||
stats[peer.PublicKey.String()] = WGStats{
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
TxBytes: peer.TransmitBytes,
|
||||
RxBytes: peer.ReceiveBytes,
|
||||
}
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -14,22 +16,40 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
privateKey = "private_key"
|
||||
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||
ipcKeyTxBytes = "tx_bytes"
|
||||
ipcKeyRxBytes = "rx_bytes"
|
||||
allowedIP = "allowed_ip"
|
||||
endpoint = "endpoint"
|
||||
fwmark = "fwmark"
|
||||
listenPort = "listen_port"
|
||||
publicKey = "public_key"
|
||||
presharedKey = "preshared_key"
|
||||
)
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
|
||||
type WGUSPConfigurer struct {
|
||||
device *device.Device
|
||||
deviceName string
|
||||
device *device.Device
|
||||
deviceName string
|
||||
activityRecorder *bind.ActivityRecorder
|
||||
|
||||
uapiListener net.Listener
|
||||
}
|
||||
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
wgCfg := &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
wgCfg.startUAPI()
|
||||
return wgCfg
|
||||
@@ -52,7 +72,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -61,7 +81,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: allowedIps,
|
||||
AllowedIPs: prefixesToIPNets(allowedIps),
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
PresharedKey: preSharedKey,
|
||||
Endpoint: endpoint,
|
||||
@@ -71,7 +91,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||
return ipcErr
|
||||
}
|
||||
|
||||
if endpoint != nil {
|
||||
addr, err := netip.ParseAddr(endpoint.IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||
}
|
||||
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
@@ -88,13 +120,16 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
|
||||
|
||||
c.activityRecorder.Remove(peerKey)
|
||||
return ipcErr
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return err
|
||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
ipNet := net.IPNet{
|
||||
IP: allowedIP.Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||
}
|
||||
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
@@ -105,7 +140,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
PublicKey: peerKeyParsed,
|
||||
UpdateOnly: true,
|
||||
ReplaceAllowedIPs: false,
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
AllowedIPs: []net.IPNet{ipNet},
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
@@ -115,7 +150,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||
ipc, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -138,6 +173,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
|
||||
foundPeer := false
|
||||
removedAllowedIP := false
|
||||
ip := allowedIP.String()
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
@@ -160,8 +197,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
|
||||
// Append the line to the output string
|
||||
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
||||
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
|
||||
_, ipNet, err := net.ParseCIDR(allowedIPStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -178,6 +215,19 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||
ipcStr, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
||||
}
|
||||
|
||||
return parseStatus(c.deviceName, ipcStr)
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return c.activityRecorder.GetLastActivities()
|
||||
}
|
||||
|
||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||
func (t *WGUSPConfigurer) startUAPI() {
|
||||
var err error
|
||||
@@ -217,91 +267,75 @@ func (t *WGUSPConfigurer) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
ipc, err := t.device.IpcGet()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
||||
return nil, fmt.Errorf("ipc get: %w", err)
|
||||
}
|
||||
|
||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
||||
"last_handshake_time_sec",
|
||||
"last_handshake_time_nsec",
|
||||
"tx_bytes",
|
||||
"rx_bytes",
|
||||
})
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
||||
}
|
||||
|
||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
||||
}
|
||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
||||
}
|
||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
||||
}
|
||||
|
||||
return WGStats{
|
||||
LastHandshake: time.Unix(sec, nsec),
|
||||
TxBytes: txBytes,
|
||||
RxBytes: rxBytes,
|
||||
}, nil
|
||||
return parseTransfers(ipc)
|
||||
}
|
||||
|
||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
|
||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
||||
|
||||
lines := strings.Split(ipcInput, "\n")
|
||||
|
||||
configFound := map[string]string{}
|
||||
foundPeer := false
|
||||
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
var (
|
||||
currentKey string
|
||||
currentStats WGStats
|
||||
hasPeer bool
|
||||
)
|
||||
lines := strings.Split(ipc, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// If we're within the details of the found peer and encounter another public key,
|
||||
// this means we're starting another peer's details. So, stop.
|
||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
||||
break
|
||||
}
|
||||
|
||||
// Identify the peer with the specific public key
|
||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
||||
foundPeer = true
|
||||
}
|
||||
|
||||
for _, key := range searchConfigKeys {
|
||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
||||
v := strings.SplitN(line, "=", 2)
|
||||
configFound[v[0]] = v[1]
|
||||
if strings.HasPrefix(line, "public_key=") {
|
||||
peerID := strings.TrimPrefix(line, "public_key=")
|
||||
h, err := hex.DecodeString(peerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||
}
|
||||
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||
currentStats = WGStats{} // Reset stats for the new peer
|
||||
hasPeer = true
|
||||
stats[currentKey] = currentStats
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasPeer {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.SplitN(line, "=", 2)
|
||||
if len(key) != 2 {
|
||||
continue
|
||||
}
|
||||
switch key[0] {
|
||||
case ipcKeyLastHandshakeTimeSec:
|
||||
hs, err := toLastHandshake(key[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentStats.LastHandshake = hs
|
||||
stats[currentKey] = currentStats
|
||||
case ipcKeyRxBytes:
|
||||
rxBytes, err := toBytes(key[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||
}
|
||||
currentStats.RxBytes = rxBytes
|
||||
stats[currentKey] = currentStats
|
||||
case ipcKeyTxBytes:
|
||||
TxBytes, err := toBytes(key[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||
}
|
||||
currentStats.TxBytes = TxBytes
|
||||
stats[currentKey] = currentStats
|
||||
}
|
||||
}
|
||||
|
||||
// todo: use multierr
|
||||
for _, key := range searchConfigKeys {
|
||||
if _, ok := configFound[key]; !ok {
|
||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
||||
}
|
||||
}
|
||||
if !foundPeer {
|
||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
||||
}
|
||||
|
||||
return configFound, nil
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
@@ -355,9 +389,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
return time.Unix(sec, 0), nil
|
||||
}
|
||||
|
||||
func toBytes(s string) (int64, error) {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
||||
// Decode hex string to bytes
|
||||
keyBytes, err := hex.DecodeString(hexKey)
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
||||
}
|
||||
|
||||
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
||||
if len(keyBytes) != 32 {
|
||||
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
||||
}
|
||||
|
||||
// Convert to wgtypes.Key
|
||||
var key wgtypes.Key
|
||||
copy(key[:], keyBytes)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||
stats := &Stats{DeviceName: deviceName}
|
||||
var currentPeer *Peer
|
||||
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
key := parts[0]
|
||||
val := parts[1]
|
||||
|
||||
switch key {
|
||||
case privateKey:
|
||||
key, err := hexToWireguardKey(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse private key: %v", err)
|
||||
continue
|
||||
}
|
||||
stats.PublicKey = key.PublicKey().String()
|
||||
case publicKey:
|
||||
// Save previous peer
|
||||
if currentPeer != nil {
|
||||
stats.Peers = append(stats.Peers, *currentPeer)
|
||||
}
|
||||
key, err := hexToWireguardKey(val)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse public key: %v", err)
|
||||
continue
|
||||
}
|
||||
currentPeer = &Peer{
|
||||
PublicKey: key.String(),
|
||||
}
|
||||
case listenPort:
|
||||
if port, err := strconv.Atoi(val); err == nil {
|
||||
stats.ListenPort = port
|
||||
}
|
||||
case fwmark:
|
||||
if fwmark, err := strconv.Atoi(val); err == nil {
|
||||
stats.FWMark = fwmark
|
||||
}
|
||||
case endpoint:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint: %v", err)
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse endpoint port: %v", err)
|
||||
continue
|
||||
}
|
||||
currentPeer.Endpoint = net.UDPAddr{
|
||||
IP: net.ParseIP(host),
|
||||
Port: port,
|
||||
}
|
||||
case allowedIP:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(val)
|
||||
if err == nil {
|
||||
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
||||
}
|
||||
case ipcKeyTxBytes:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
rxBytes, err := toBytes(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.TxBytes = rxBytes
|
||||
case ipcKeyRxBytes:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
rxBytes, err := toBytes(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.RxBytes = rxBytes
|
||||
|
||||
case ipcKeyLastHandshakeTimeSec:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ts, err := toLastHandshake(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentPeer.LastHandshake = ts
|
||||
case presharedKey:
|
||||
if currentPeer == nil {
|
||||
continue
|
||||
}
|
||||
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
|
||||
currentPeer.PresharedKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentPeer != nil {
|
||||
stats.Peers = append(stats.Peers, *currentPeer)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -2,10 +2,8 @@ package configurer
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -34,58 +32,35 @@ errno=0
|
||||
|
||||
`
|
||||
|
||||
func Test_findPeerInfo(t *testing.T) {
|
||||
func Test_parseTransfers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerKey string
|
||||
searchKeys []string
|
||||
want map[string]string
|
||||
wantErr bool
|
||||
name string
|
||||
peerKey string
|
||||
want WGStats
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
searchKeys: []string{"tx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "38333",
|
||||
name: "single",
|
||||
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||
want: WGStats{
|
||||
TxBytes: 0,
|
||||
RxBytes: 0,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "38333",
|
||||
"rx_bytes": "2224",
|
||||
name: "multiple",
|
||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||
want: WGStats{
|
||||
TxBytes: 38333,
|
||||
RxBytes: 2224,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "lastpeer",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "1212111",
|
||||
"rx_bytes": "1929999999",
|
||||
name: "lastpeer",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
want: WGStats{
|
||||
TxBytes: 1212111,
|
||||
RxBytes: 1929999999,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "peer not found",
|
||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
||||
searchKeys: nil,
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key not found",
|
||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
||||
want: map[string]string{
|
||||
"tx_bytes": "1212111",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
||||
key, err := wgtypes.NewKey(res)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
||||
stats, err := parseTransfers(ipcFixture)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
stat, ok := stats[key.String()]
|
||||
if !ok {
|
||||
require.True(t, ok)
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.want, stat)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
24
client/iface/configurer/wgshow.go
Normal file
24
client/iface/configurer/wgshow.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string
|
||||
Endpoint net.UDPAddr
|
||||
AllowedIPs []net.IPNet
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
LastHandshake time.Time
|
||||
PresharedKey bool
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
DeviceName string
|
||||
PublicKey string
|
||||
ListenPort int
|
||||
FWMark int
|
||||
Peers []Peer
|
||||
}
|
||||
@@ -24,6 +24,7 @@ type WGTunDevice struct {
|
||||
mtu int
|
||||
iceBind *bind.ICEBind
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
name string
|
||||
device *device.Device
|
||||
@@ -32,7 +33,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
routesString := routesToString(routes)
|
||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||
|
||||
// Skip DNS configuration when DisableDNS is enabled
|
||||
if t.disableDNS {
|
||||
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||
dns = ""
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
@@ -70,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -10,11 +9,11 @@ import (
|
||||
|
||||
// PacketFilter interface for firewall abilities
|
||||
type PacketFilter interface {
|
||||
// DropOutgoing filter outgoing packets from host to external destinations
|
||||
DropOutgoing(packetData []byte, size int) bool
|
||||
// FilterOutbound filter outgoing packets from host to external destinations
|
||||
FilterOutbound(packetData []byte, size int) bool
|
||||
|
||||
// DropIncoming filter incoming packets from external sources to host
|
||||
DropIncoming(packetData []byte, size int) bool
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
@@ -24,9 +23,6 @@ type PacketFilter interface {
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
SetNetwork(*net.IPNet)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
@@ -58,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||
n--
|
||||
@@ -82,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
dropped := 0
|
||||
for _, buf := range bufs {
|
||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
dropped++
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
return 1, nil
|
||||
})
|
||||
filter := mocks.NewMockPacketFilter(ctrl)
|
||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
wrapped.filter = filter
|
||||
|
||||
@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
@@ -99,8 +100,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var udpConn net.PacketConn = rawSock
|
||||
if !nbnet.AdvancedRouting() {
|
||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
||||
}
|
||||
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: rawSock,
|
||||
UDPConn: udpConn,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
|
||||
@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create nbnetstack tun interface")
|
||||
|
||||
// TODO: get from service listener runtime IP
|
||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("last ip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
@@ -68,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
|
||||
@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||
}
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
t.device.Close()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user