mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
314 Commits
fix/iptabl
...
test/abc2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
480e00ce43 | ||
|
|
01e88ed6ab | ||
|
|
be209e7841 | ||
|
|
7c22a4ba9b | ||
|
|
f86aa47933 | ||
|
|
aafa342786 | ||
|
|
042141db06 | ||
|
|
4a1aee1ae0 | ||
|
|
ba33572ec9 | ||
|
|
9d213e0b54 | ||
|
|
5dde044fa5 | ||
|
|
5a3d9e401f | ||
|
|
fde1a2196c | ||
|
|
0aeb87742a | ||
|
|
6d747b2f83 | ||
|
|
199bf73103 | ||
|
|
17f5abc653 | ||
|
|
aa935bdae3 | ||
|
|
452419c4c3 | ||
|
|
17b1099032 | ||
|
|
a4b9e93217 | ||
|
|
63d7957140 | ||
|
|
9a6814deff | ||
|
|
190698bcf2 | ||
|
|
468fa2940b | ||
|
|
79a0647a26 | ||
|
|
17ceb3bde8 | ||
|
|
5a8f1763a6 | ||
|
|
f64e73ca70 | ||
|
|
b085419ab8 | ||
|
|
d78b652ff7 | ||
|
|
7251150c1c | ||
|
|
b65c2f69b0 | ||
|
|
d8ce08d898 | ||
|
|
e1c50248d9 | ||
|
|
ce2d14c08e | ||
|
|
52fd9a575a | ||
|
|
9028c3c1f7 | ||
|
|
9357a587e9 | ||
|
|
a47c69c472 | ||
|
|
bbea4c3cc3 | ||
|
|
b7a6cbfaa5 | ||
|
|
e18bf565a2 | ||
|
|
51fa3c92c5 | ||
|
|
d65602f904 | ||
|
|
8d9e1fed5f | ||
|
|
e1eddd1cab | ||
|
|
0fbf72434e | ||
|
|
51f133fdc6 | ||
|
|
d5338c09dc | ||
|
|
8fd4166c53 | ||
|
|
9bc7b9e897 | ||
|
|
db3cba5e0f | ||
|
|
cb3408a10b | ||
|
|
0afd738509 | ||
|
|
cf87f1e702 | ||
|
|
e890fdae54 | ||
|
|
dd14db6478 | ||
|
|
88747e3e01 | ||
|
|
fb30931365 | ||
|
|
a7547b9990 | ||
|
|
62bacee8dc | ||
|
|
71cd2e3e03 | ||
|
|
bdf71ab7ff | ||
|
|
a2f2a6e21a | ||
|
|
f89332fcd2 | ||
|
|
8604add997 | ||
|
|
93cab49696 | ||
|
|
b6835d9467 | ||
|
|
846d486366 | ||
|
|
9c56f74235 | ||
|
|
25b3641be8 | ||
|
|
c41504b571 | ||
|
|
399493a954 | ||
|
|
4771fed64f | ||
|
|
88117f7d16 | ||
|
|
5ac9f9fe2f | ||
|
|
a7d6632298 | ||
|
|
d4194cba6a | ||
|
|
131d9f1bc7 | ||
|
|
f099e02b34 | ||
|
|
93646e6a13 | ||
|
|
67a2127fd7 | ||
|
|
dd7fcbd083 | ||
|
|
d5f330b9c0 | ||
|
|
9fa0fbda0d | ||
|
|
5a7aa461de | ||
|
|
e9c967b27c | ||
|
|
ace588758c | ||
|
|
8bb16e016c | ||
|
|
6a2a97f088 | ||
|
|
3591795a58 | ||
|
|
5311ce4e4a | ||
|
|
c61cb00f40 | ||
|
|
72a1e97304 | ||
|
|
5242851ecc | ||
|
|
cb69348a30 | ||
|
|
69dbcbd362 | ||
|
|
5de4acf2fe | ||
|
|
aa3b79d311 | ||
|
|
8b4ec96516 | ||
|
|
1f3a12d941 | ||
|
|
1de3bb5420 | ||
|
|
163933d429 | ||
|
|
875a2e2b63 | ||
|
|
fd8bba6aa3 | ||
|
|
86908eee58 | ||
|
|
c1caec3fcb | ||
|
|
b28b8fce50 | ||
|
|
f780f17f85 | ||
|
|
5903715a61 | ||
|
|
5469de53c5 | ||
|
|
bc3d647d6b | ||
|
|
7060b63838 | ||
|
|
3168b80ad0 | ||
|
|
818c6b885f | ||
|
|
01f28baec7 | ||
|
|
56896794b3 | ||
|
|
f73a2e2848 | ||
|
|
19fa071a93 | ||
|
|
cba3c549e9 | ||
|
|
65247de48d | ||
|
|
2d1dfa3ae7 | ||
|
|
5961c8330e | ||
|
|
d275d411aa | ||
|
|
5ecafef5d2 | ||
|
|
d073a250cc | ||
|
|
a1c48468ab | ||
|
|
dd1e730454 | ||
|
|
050f140245 | ||
|
|
006ba32086 | ||
|
|
b03343bc4d | ||
|
|
36d62f1844 | ||
|
|
08733ed8d5 | ||
|
|
27ed88f918 | ||
|
|
45fc89b2c9 | ||
|
|
f822a58326 | ||
|
|
d1f13025d1 | ||
|
|
3f8b500f0b | ||
|
|
0d2db4b172 | ||
|
|
7a18dea766 | ||
|
|
ae5f69562d | ||
|
|
755ffcfc73 | ||
|
|
dc8f55f23e | ||
|
|
89249b414f | ||
|
|
92adf57fea | ||
|
|
e37a337164 | ||
|
|
1cd5a66575 | ||
|
|
b9fc008542 | ||
|
|
d5bf79bc51 | ||
|
|
d7efea74b6 | ||
|
|
b8c46e2654 | ||
|
|
4bf574037f | ||
|
|
47c44d4b87 | ||
|
|
96f866fb68 | ||
|
|
141065f14e | ||
|
|
8e74fb1fa8 | ||
|
|
ba96e102b4 | ||
|
|
7a46a63a14 | ||
|
|
2129b23fe7 | ||
|
|
b6211ad020 | ||
|
|
efd05ca023 | ||
|
|
c829ad930c | ||
|
|
ad1f18a52a | ||
|
|
bab420ca77 | ||
|
|
c2eaf8a1c0 | ||
|
|
a729c83b06 | ||
|
|
dc05102b8f | ||
|
|
a7e55cc5e3 | ||
|
|
b7c0eba1e5 | ||
|
|
d1a323fa9d | ||
|
|
63d211c698 | ||
|
|
0ca06b566a | ||
|
|
cf9e447bf0 | ||
|
|
fdd23d4644 | ||
|
|
5a3ee4f9c4 | ||
|
|
5ffed796c0 | ||
|
|
ab895be4a3 | ||
|
|
96cdcf8e49 | ||
|
|
63f6514be5 | ||
|
|
afece95ae5 | ||
|
|
d78b7e5d93 | ||
|
|
67906f6da5 | ||
|
|
52b5a31058 | ||
|
|
b58094de0f | ||
|
|
456aaf2868 | ||
|
|
d379c25ff5 | ||
|
|
f86ed12cf5 | ||
|
|
5a45f79fec | ||
|
|
e7d063126d | ||
|
|
fb42fedb58 | ||
|
|
9eb1e90bbe | ||
|
|
53fb0a9754 | ||
|
|
70c7543e36 | ||
|
|
d1d01a0611 | ||
|
|
9e8725618e | ||
|
|
a40261ff7e | ||
|
|
89e8540531 | ||
|
|
9f7e13fc87 | ||
|
|
8be6e92563 | ||
|
|
b726b3262d | ||
|
|
125a7a9daf | ||
|
|
9b1a0c2df7 | ||
|
|
1568c8aa91 | ||
|
|
2f5ba96596 | ||
|
|
63568e5e0e | ||
|
|
9c4bf1e899 | ||
|
|
2c01514259 | ||
|
|
e2f27502e4 | ||
|
|
8cf2866a6a | ||
|
|
c99ae6f009 | ||
|
|
8843784312 | ||
|
|
c38d65ef4c | ||
|
|
6d4240a5ae | ||
|
|
52f5101715 | ||
|
|
e2eef4e3fd | ||
|
|
76318f3f06 | ||
|
|
db25ca21a8 | ||
|
|
a8d03d8c91 | ||
|
|
74ff2619d0 | ||
|
|
40bea645e9 | ||
|
|
e7d52beeab | ||
|
|
7a5c6b24ae | ||
|
|
90c2093018 | ||
|
|
06318a15e1 | ||
|
|
eeb38b7ecf | ||
|
|
e59d2317fe | ||
|
|
ee6be58a67 | ||
|
|
a9f5fad625 | ||
|
|
c979a4e9fb | ||
|
|
f2fc0df104 | ||
|
|
87cc53b743 | ||
|
|
7d8a69cc0c | ||
|
|
e4de1d75de | ||
|
|
73e57f17ea | ||
|
|
46f5f148da | ||
|
|
32880c56a4 | ||
|
|
2b90ff8c24 | ||
|
|
b8599f634c | ||
|
|
659110f0d5 | ||
|
|
4ad14cb46b | ||
|
|
3c485dc7a1 | ||
|
|
f7e6cdcbf0 | ||
|
|
af6fdd3af2 | ||
|
|
5781ec7a8e | ||
|
|
1219006a6e | ||
|
|
4791e41004 | ||
|
|
9131069d12 | ||
|
|
26bbc33e7a | ||
|
|
35bc493cc3 | ||
|
|
e26ec0b937 | ||
|
|
a952e7c72f | ||
|
|
22f69d7852 | ||
|
|
b23011fbe8 | ||
|
|
6ad3894a51 | ||
|
|
c81b83b346 | ||
|
|
8c5c6815e0 | ||
|
|
0c470e7838 | ||
|
|
8118d60ffb | ||
|
|
1956ca169e | ||
|
|
830dee1771 | ||
|
|
c08a96770e | ||
|
|
c6bf1c7f26 | ||
|
|
5f499d66b2 | ||
|
|
7c065bd9fc | ||
|
|
ab849f0942 | ||
|
|
aa1d31bde6 | ||
|
|
5b4dc4dd47 | ||
|
|
1324169ebb | ||
|
|
732afd8393 | ||
|
|
da7b6b11ad | ||
|
|
e260270825 | ||
|
|
d4b6d7646c | ||
|
|
8febab4076 | ||
|
|
34e2c6b943 | ||
|
|
0be8c72601 | ||
|
|
c34e53477f | ||
|
|
8d18190c94 | ||
|
|
06bec61be9 | ||
|
|
2135533f1d | ||
|
|
bb791d59f3 | ||
|
|
30f1c54ed1 | ||
|
|
5c8541ef42 | ||
|
|
fa4b8c1d42 | ||
|
|
7682fe2e45 | ||
|
|
c9b2ce08eb | ||
|
|
246abda46d | ||
|
|
e4bc76c4de | ||
|
|
bdb8383485 | ||
|
|
bb40325977 | ||
|
|
8524cc75d6 | ||
|
|
c1f164c9cb | ||
|
|
4e2d075413 | ||
|
|
f89c200ce9 | ||
|
|
d51dc4fd33 | ||
|
|
00dddb9458 | ||
|
|
1a9301b684 | ||
|
|
80d9b5fca5 | ||
|
|
ac0b7dc8cb | ||
|
|
e586eca16c | ||
|
|
892db25021 | ||
|
|
da75a76d41 | ||
|
|
3ac32fd78a | ||
|
|
3aa657599b | ||
|
|
d4e9087f94 | ||
|
|
da8447a67d | ||
|
|
8e3bcd57a2 | ||
|
|
4572c6c1f8 | ||
|
|
01f2b0ecb7 | ||
|
|
442ba7cbc8 | ||
|
|
6c2b364966 | ||
|
|
0f0c7ec2ed | ||
|
|
2dec016201 | ||
|
|
06125acb8d |
15
.devcontainer/Dockerfile
Normal file
15
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,15 @@
|
||||
FROM golang:1.21-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
gettext-base=0.21-4 \
|
||||
iptables=1.8.7-1 \
|
||||
libgl1-mesa-dev=20.3.5-1 \
|
||||
xorg-dev=1:7.7+22 \
|
||||
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& go install -v golang.org/x/tools/gopls@latest
|
||||
|
||||
|
||||
WORKDIR /app
|
||||
20
.devcontainer/devcontainer.json
Normal file
20
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"name": "NetBird",
|
||||
"build": {
|
||||
"context": "..",
|
||||
"dockerfile": "Dockerfile"
|
||||
},
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||
"ghcr.io/devcontainers/features/go:1": {
|
||||
"version": "1.21"
|
||||
}
|
||||
},
|
||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||
"capAdd": [
|
||||
"NET_ADMIN",
|
||||
"SYS_ADMIN",
|
||||
"SYS_RESOURCE"
|
||||
],
|
||||
"privileged": true
|
||||
}
|
||||
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
||||
*.go text eol=lf
|
||||
18
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
18
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -2,15 +2,17 @@
|
||||
name: Bug/Issue report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
labels: ['triage-needed']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the problem**
|
||||
|
||||
A clear and concise description of what the problem is.
|
||||
|
||||
**To Reproduce**
|
||||
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
@@ -18,13 +20,25 @@ Steps to reproduce the behavior:
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Are you using NetBird Cloud?**
|
||||
|
||||
Please specify whether you use NetBird Cloud or self-host NetBird's control plane.
|
||||
|
||||
**NetBird version**
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -d output:**
|
||||
If applicable, add the output of the `netbird status -d` command
|
||||
|
||||
If applicable, add the `netbird status -d' command output.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -2,7 +2,7 @@
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
labels: ['feature-request']
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
16
.github/workflows/golang-test-darwin.yml
vendored
16
.github/workflows/golang-test-darwin.yml
vendored
@@ -12,17 +12,20 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['jsonfile', 'sqlite']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -32,5 +35,8 @@ jobs:
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
|
||||
40
.github/workflows/golang-test-linux.yml
vendored
40
.github/workflows/golang-test-linux.yml
vendored
@@ -15,16 +15,17 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['386','amd64']
|
||||
store: ['jsonfile', 'sqlite']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -32,7 +33,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
@@ -40,20 +41,22 @@ jobs:
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -61,14 +64,17 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
@@ -82,7 +88,7 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o engine-testing.bin ./client/internal
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||
@@ -95,15 +101,17 @@ jobs:
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
11
.github/workflows/golang-test-windows.yml
vendored
11
.github/workflows/golang-test-windows.yml
vendored
@@ -23,13 +23,13 @@ jobs:
|
||||
uses: actions/setup-go@v4
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
with:
|
||||
file-url: https://www.wintun.net/builds/wintun-0.14.1.zip
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
@@ -39,12 +39,15 @@ jobs:
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
- run: choco install -y sysinternals
|
||||
- run: choco install -y sysinternals --ignore-checksums
|
||||
- run: choco install -y mingw
|
||||
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
- run: "[Environment]::SetEnvironmentVariable('NETBIRD_STORE_ENGINE', 'jsonfile', 'Machine')"
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
43
.github/workflows/golangci-lint.yml
vendored
43
.github/workflows/golangci-lint.yml
vendored
@@ -1,21 +1,48 @@
|
||||
name: golangci-lint
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
codespell:
|
||||
name: codespell
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
ignore_words_list: erro,clienta
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||
name: lint
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
args: --timeout=6m
|
||||
version: latest
|
||||
args: --timeout=12m
|
||||
36
.github/workflows/install-script-test.yml
vendored
Normal file
36
.github/workflows/install-script-test.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Test installation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
test-install-script:
|
||||
strategy:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
skip_ui_mode: [true, false]
|
||||
install_binary: [true, false]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
|
||||
USE_BIN_INSTALL: ${{ matrix.install_binary }}
|
||||
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
|
||||
run: |
|
||||
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
|
||||
cat release_files/install.sh | sh -x
|
||||
|
||||
- name: check cli binary
|
||||
run: command -v netbird
|
||||
60
.github/workflows/install-test-darwin.yml
vendored
60
.github/workflows/install-test-darwin.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Test installation Darwin
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
env:
|
||||
SKIP_UI_APP: true
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
install-all:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
||||
echo "Error: NetBird UI is not installed"
|
||||
exit 1
|
||||
fi
|
||||
38
.github/workflows/install-test-linux.yml
vendored
38
.github/workflows/install-test-linux.yml
vendored
@@ -1,38 +0,0 @@
|
||||
name: Test installation Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
check_bin_install: [true, false]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename apt package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: |
|
||||
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
||||
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
65
.github/workflows/mobile-build-validation.yml
vendored
Normal file
65
.github/workflows/mobile-build-validation.yml
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
name: Mobile build validation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
android_build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
ios_build:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
87
.github/workflows/release.yml
vendored
87
.github/workflows/release.yml
vendored
@@ -17,9 +17,10 @@ on:
|
||||
- 'release_files/**'
|
||||
- '**/Dockerfile'
|
||||
- '**/Dockerfile.*'
|
||||
- 'client/ui/**'
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.8"
|
||||
SIGN_PIPE_VER: "v0.0.11"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
|
||||
concurrency:
|
||||
@@ -29,25 +30,32 @@ concurrency:
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v1
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
${{ runner.os }}-go-releaser-
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -56,10 +64,10 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
-
|
||||
name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
uses: docker/setup-qemu-action@v2
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
uses: docker/setup-buildx-action@v2
|
||||
-
|
||||
name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
@@ -82,10 +90,10 @@ jobs:
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --rm-dist
|
||||
args: release --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
@@ -93,7 +101,7 @@ jobs:
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v2
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
@@ -102,22 +110,27 @@ jobs:
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v1
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-
|
||||
${{ runner.os }}-ui-go-releaser-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -132,17 +145,17 @@ jobs:
|
||||
- name: Generate windows rsrc
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v2
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -151,39 +164,47 @@ jobs:
|
||||
release_ui_darwin:
|
||||
runs-on: macos-11
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20"
|
||||
go-version: "1.21"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v1
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-ui-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-
|
||||
${{ runner.os }}-ui-go-releaser-darwin-
|
||||
-
|
||||
name: Install modules
|
||||
run: go mod tidy
|
||||
-
|
||||
name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
-
|
||||
name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v2
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v2
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
|
||||
22
.github/workflows/sync-main.yml
vendored
Normal file
22
.github/workflows/sync-main.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: sync main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
trigger_sync_main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
23
.github/workflows/sync-tag.yml
vendored
Normal file
23
.github/workflows/sync-tag.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: sync tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
trigger_sync_tag:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
96
.github/workflows/test-infrastructure-files.yml
vendored
96
.github/workflows/test-infrastructure-files.yml
vendored
@@ -8,7 +8,8 @@ on:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -25,12 +26,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.20.x"
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -57,9 +58,11 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
- name: check values
|
||||
working-directory: infrastructure_files
|
||||
working-directory: infrastructure_files/artifacts
|
||||
env:
|
||||
CI_NETBIRD_DOMAIN: localhost
|
||||
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
|
||||
@@ -81,8 +84,13 @@ jobs:
|
||||
CI_NETBIRD_MGMT_IDP: "none"
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
|
||||
run: |
|
||||
set -x
|
||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
@@ -92,27 +100,56 @@ jobs:
|
||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
|
||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep -A 8 DeviceAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
|
||||
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
|
||||
grep UseIDToken management.json | grep false
|
||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
|
||||
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
|
||||
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
|
||||
grep -A 2 PKCEAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||
grep -A 3 PKCEAuthorizationFlow management.json | grep -A 2 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep -A 4 PKCEAuthorizationFlow management.json | grep -A 3 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
grep -A 5 PKCEAuthorizationFlow management.json | grep -A 4 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Build management binary
|
||||
working-directory: management
|
||||
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
|
||||
|
||||
- name: Build management docker image
|
||||
working-directory: management
|
||||
run: |
|
||||
docker build -t netbirdio/management:latest .
|
||||
|
||||
- name: Build signal binary
|
||||
working-directory: signal
|
||||
run: CGO_ENABLED=0 go build -o netbird-signal main.go
|
||||
|
||||
- name: Build signal docker image
|
||||
working-directory: signal
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
docker-compose up -d
|
||||
sleep 5
|
||||
@@ -121,9 +158,16 @@ jobs:
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
count=$(docker compose ps --format json | jq '.[] | select(.Project | contains("infrastructure_files")) | .State' | grep -c running)
|
||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
|
||||
test $count -eq 4
|
||||
working-directory: infrastructure_files
|
||||
working-directory: infrastructure_files/artifacts
|
||||
|
||||
- name: test geolocation databases
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -144,8 +188,24 @@ jobs:
|
||||
- name: test management.json file gen
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
run: test -f turnserver.conf
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
run: test -f dashboard.env
|
||||
run: test -f dashboard.env
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
||||
21
.gitignore
vendored
21
.gitignore
vendored
@@ -6,11 +6,20 @@ bin/
|
||||
.env
|
||||
conf.json
|
||||
http-cmds.sh
|
||||
infrastructure_files/management.json
|
||||
infrastructure_files/management-*.json
|
||||
infrastructure_files/docker-compose.yml
|
||||
infrastructure_files/openid-configuration.json
|
||||
infrastructure_files/turnserver.conf
|
||||
setup.env
|
||||
infrastructure_files/**/Caddyfile
|
||||
infrastructure_files/**/dashboard.env
|
||||
infrastructure_files/**/zitadel.env
|
||||
infrastructure_files/**/management.json
|
||||
infrastructure_files/**/management-*.json
|
||||
infrastructure_files/**/docker-compose.yml
|
||||
infrastructure_files/**/openid-configuration.json
|
||||
infrastructure_files/**/turnserver.conf
|
||||
infrastructure_files/**/management.json.bkp.**
|
||||
infrastructure_files/**/management-*.json.bkp.**
|
||||
infrastructure_files/**/docker-compose.yml.bkp.**
|
||||
infrastructure_files/**/openid-configuration.json.bkp.**
|
||||
infrastructure_files/**/turnserver.conf.bkp.**
|
||||
management/management
|
||||
client/client
|
||||
client/client.exe
|
||||
@@ -19,3 +28,5 @@ client/.distfiles/
|
||||
infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
GeoLite2-City*
|
||||
132
.golangci.yaml
Normal file
132
.golangci.yaml
Normal file
@@ -0,0 +1,132 @@
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 6m
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: false
|
||||
|
||||
gosec:
|
||||
includes:
|
||||
- G101 # Look for hard coded credentials
|
||||
#- G102 # Bind to all interfaces
|
||||
- G103 # Audit the use of unsafe block
|
||||
- G104 # Audit errors not checked
|
||||
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
|
||||
#- G107 # Url provided to HTTP request as taint input
|
||||
- G108 # Profiling endpoint automatically exposed on /debug/pprof
|
||||
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
|
||||
- G110 # Potential DoS vulnerability via decompression bomb
|
||||
- G111 # Potential directory traversal
|
||||
#- G112 # Potential slowloris attack
|
||||
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
|
||||
#- G114 # Use of net/http serve function that has no support for setting timeouts
|
||||
- G201 # SQL query construction using format string
|
||||
- G202 # SQL query construction using string concatenation
|
||||
- G203 # Use of unescaped data in HTML templates
|
||||
#- G204 # Audit use of command execution
|
||||
- G301 # Poor file permissions used when creating a directory
|
||||
- G302 # Poor file permissions used with chmod
|
||||
- G303 # Creating tempfile using a predictable path
|
||||
- G304 # File path provided as taint input
|
||||
- G305 # File traversal when extracting zip/tar archive
|
||||
- G306 # Poor file permissions used when writing to a new file
|
||||
- G307 # Poor file permissions used when creating a file with os.Create
|
||||
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
|
||||
#- G402 # Look for bad TLS connection settings
|
||||
- G403 # Ensure minimum RSA key length of 2048 bits
|
||||
#- G404 # Insecure random number source (rand)
|
||||
#- G501 # Import blocklist: crypto/md5
|
||||
- G502 # Import blocklist: crypto/des
|
||||
- G503 # Import blocklist: crypto/rc4
|
||||
- G504 # Import blocklist: net/http/cgi
|
||||
#- G505 # Import blocklist: crypto/sha1
|
||||
- G601 # Implicit memory aliasing of items from a range statement
|
||||
- G602 # Slice access out of bounds
|
||||
|
||||
gocritic:
|
||||
disabled-checks:
|
||||
- commentFormatting
|
||||
- captLocal
|
||||
- deprecatedComment
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: false
|
||||
enable:
|
||||
- nilness
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
severity: warning
|
||||
disabled: false
|
||||
arguments:
|
||||
- "checkPrivateReceivers"
|
||||
- "sayRepetitiveInsteadOfStutters"
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
|
||||
- gosimple # specializes in simplifying a code
|
||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # detects when assignments to existing variables are not used
|
||||
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
|
||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
||||
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # checks for unused constants, variables, functions and types
|
||||
## disable by default but the have interesting results so lets add them
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- dupword # dupword checks for duplicate words in the source code
|
||||
- durationcheck # durationcheck checks for two durations multiplied together
|
||||
- forbidigo # forbidigo forbids identifiers
|
||||
- gocritic # provides diagnostics that check for bugs, performance and style issues
|
||||
- gosec # inspects source code for security problems
|
||||
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
|
||||
- misspell # misspess finds commonly misspelled English words in comments
|
||||
- nilerr # finds the code that returns nil even if it checks that the error is not nil
|
||||
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 5
|
||||
|
||||
exclude-rules:
|
||||
# allow fmt
|
||||
- path: management/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: signal/cmd/root\.go
|
||||
linters: forbidigo
|
||||
- path: sharedsock/filter\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: client/firewall/iptables/rule\.go
|
||||
linters:
|
||||
- unused
|
||||
- path: test\.go
|
||||
linters:
|
||||
- mirror
|
||||
- gosec
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
@@ -54,7 +54,7 @@ nfpms:
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/disconnected.png
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -71,7 +71,7 @@ nfpms:
|
||||
contents:
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/disconnected.png
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -91,4 +91,4 @@ uploads:
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
method: PUT
|
||||
method: PUT
|
||||
|
||||
@@ -19,11 +19,11 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
- [Local NetBird setup](#local-netbird-setup)
|
||||
- [Dev Container Support](#dev-container-support)
|
||||
- [Build and start](#build-and-start)
|
||||
- [Test suite](#test-suite)
|
||||
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
|
||||
- [Other project repositories](#other-project-repositories)
|
||||
- [Checklist before submitting a new node](#checklist-before-submitting-a-new-node)
|
||||
- [Contributor License Agreement](#contributor-license-agreement)
|
||||
|
||||
## Code of conduct
|
||||
@@ -70,7 +70,7 @@ dependencies are installed. Here is a short guide on how that can be done.
|
||||
|
||||
### Requirements
|
||||
|
||||
#### Go 1.19
|
||||
#### Go 1.21
|
||||
|
||||
Follow the installation guide from https://go.dev/
|
||||
|
||||
@@ -136,18 +136,61 @@ checked out and set up:
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
### Dev Container Support
|
||||
|
||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||
Dev containers provide a consistent and isolated development environment, making it easier for contributors to get started quickly. Follow the steps below to set up NetBird in a dev container.
|
||||
|
||||
#### 1. Prerequisites:
|
||||
|
||||
* Install Docker on your machine: [Docker Installation Guide](https://docs.docker.com/get-docker/)
|
||||
* Install Visual Studio Code: [VS Code Installation Guide](https://code.visualstudio.com/download)
|
||||
* If you prefer JetBrains Goland please follow this [manual](https://www.jetbrains.com/help/go/connect-to-devcontainer.html)
|
||||
|
||||
#### 2. Clone the Repository:
|
||||
|
||||
Clone the repository following previous [Local NetBird setup](#local-netbird-setup).
|
||||
|
||||
#### 3. Open in project in IDE of your choice:
|
||||
|
||||
**VScode**:
|
||||
|
||||
Open the project folder in Visual Studio Code:
|
||||
|
||||
```bash
|
||||
code .
|
||||
```
|
||||
|
||||
When you open the project in VS Code, it will detect the presence of a dev container configuration.
|
||||
Click on the green "Reopen in Container" button in the bottom-right corner of VS Code.
|
||||
|
||||
**Goland**:
|
||||
|
||||
Open GoLand and select `"File" > "Open"` to open the NetBird project folder.
|
||||
GoLand will detect the dev container configuration and prompt you to open the project in the container. Accept the prompt.
|
||||
|
||||
#### 4. Wait for the Container to Build:
|
||||
|
||||
VsCode or GoLand will use the specified Docker image to build the dev container. This might take some time, depending on your internet connection.
|
||||
|
||||
#### 6. Development:
|
||||
|
||||
Once the container is built, you can start developing within the dev container. All the necessary dependencies and configurations are set up within the container.
|
||||
|
||||
|
||||
### Build and start
|
||||
#### Client
|
||||
|
||||
> Windows clients have a Wireguard driver requirement. We provide a bash script that can be executed in WLS 2 with docker support [wireguard_nt.sh](/client/wireguard_nt.sh).
|
||||
|
||||
To start NetBird, execute:
|
||||
```
|
||||
cd client
|
||||
# bash wireguard_nt.sh # if windows
|
||||
go build .
|
||||
CGO_ENABLED=0 go build .
|
||||
```
|
||||
|
||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to the same path as your binary file or to `C:\Windows\System32\wintun.dll`.
|
||||
|
||||
> To test the client GUI application on Windows machines with RDP or vituralized environments (e.g. virtualbox or cloud), you need to download and extract the opengl32.dll from https://fdossena.com/?p=mesa/index.frag next to the built application.
|
||||
|
||||
To start NetBird the client in the foreground:
|
||||
|
||||
```
|
||||
@@ -185,6 +228,42 @@ To start NetBird the management service:
|
||||
./management management --log-level debug --log-file console --config ./management.json
|
||||
```
|
||||
|
||||
#### Windows Netbird Installer
|
||||
Create dist directory
|
||||
```shell
|
||||
mkdir -p dist/netbird_windows_amd64
|
||||
```
|
||||
|
||||
UI client
|
||||
```shell
|
||||
CC=x86_64-w64-mingw32-gcc CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -o netbird-ui.exe -ldflags "-s -w -H windowsgui" ./client/ui
|
||||
mv netbird-ui.exe ./dist/netbird_windows_amd64/
|
||||
```
|
||||
|
||||
Client
|
||||
```shell
|
||||
CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o netbird.exe ./client/
|
||||
mv netbird.exe ./dist/netbird_windows_amd64/
|
||||
```
|
||||
> Windows clients have a Wireguard driver requirement. You can download the wintun driver from https://www.wintun.net/builds/wintun-0.14.1.zip, after decompressing, you can copy the file `windtun\bin\ARCH\wintun.dll` to `./dist/netbird_windows_amd64/`.
|
||||
|
||||
NSIS compiler
|
||||
- [Windows-nsis]( https://nsis.sourceforge.io/Download)
|
||||
- [MacOS-makensis](https://formulae.brew.sh/formula/makensis#default)
|
||||
- [Linux-makensis](https://manpages.ubuntu.com/manpages/trusty/man1/makensis.1.html)
|
||||
|
||||
NSIS Plugins. Download and move them to the NSIS plugins folder.
|
||||
- [EnVar](https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip)
|
||||
- [ShellExecAsUser](https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z)
|
||||
|
||||
Windows Installer
|
||||
```shell
|
||||
export APPVER=0.0.0.1
|
||||
makensis -V4 client/installer.nsis
|
||||
```
|
||||
|
||||
The installer `netbird-installer.exe` will be created in root directory.
|
||||
|
||||
### Test suite
|
||||
|
||||
The tests can be started via:
|
||||
@@ -195,6 +274,8 @@ go test -exec sudo ./...
|
||||
```
|
||||
> On Windows use a powershell with administrator privileges
|
||||
|
||||
> Non-GTK environments will need the `libayatana-appindicator3-dev` (debian/ubuntu) package installed
|
||||
|
||||
## Checklist before submitting a PR
|
||||
As a critical network service and open-source project, we must enforce a few things before submitting the pull-requests:
|
||||
- Keep functions as simple as possible, with a single purpose
|
||||
@@ -215,4 +296,4 @@ NetBird project is composed of 3 main repositories:
|
||||
|
||||
That we do not have any potential problems later it is sadly necessary to sign a [Contributor License Agreement](CONTRIBUTOR_LICENSE_AGREEMENT.md). That can be done literally with the push of a button.
|
||||
|
||||
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.
|
||||
A bot will automatically comment on the pull request once it got opened asking for the agreement to be signed. Before it did not get signed it is sadly not possible to merge it in.
|
||||
|
||||
45
README.md
45
README.md
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! Peer expiration.</strong>
|
||||
<a href="https://github.com/netbirdio/netbird/releases">
|
||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
||||
Learn more
|
||||
</a>
|
||||
</p>
|
||||
@@ -24,7 +24,7 @@
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://app.netbird.io/">app.netbird.io</a>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
@@ -40,27 +40,24 @@
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird isolates every machine and device by applying granular access policies, while allowing you to manage them intuitively from a single place.
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
**Key features:**
|
||||
### Open-Source Network Security in a Single Platform
|
||||
|
||||
| Connectivity | Management | Automation | Platforms |
|
||||
|-------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[ ] iOS </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
|
||||
| | <ul><li> - \[x] SSH access management </ul></li> | | |
|
||||

|
||||
|
||||
### Key features
|
||||
|
||||
### Secure peer-to-peer VPN with SSO and MFA in minutes
|
||||
|
||||
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
@@ -109,8 +106,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
@@ -126,4 +124,7 @@ We use open-source technologies like [WireGuard®](https://www.wireguard.com/),
|
||||
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
|
||||
dddd
|
||||
|
||||
ccc
|
||||
@@ -18,10 +18,9 @@ func Encode(num uint32) string {
|
||||
}
|
||||
|
||||
var encoded strings.Builder
|
||||
remainder := uint32(0)
|
||||
|
||||
for num > 0 {
|
||||
remainder = num % base
|
||||
remainder := num % base
|
||||
encoded.WriteByte(alphabet[remainder])
|
||||
num /= base
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
FROM gcr.io/distroless/base:debug
|
||||
FROM alpine:3.18.5
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV PATH=/sbin:/usr/sbin:/bin:/usr/bin:/busybox
|
||||
SHELL ["/busybox/sh","-c"]
|
||||
RUN sed -i -E 's/(^root:.+)\/sbin\/nologin/\1\/busybox\/sh/g' /etc/passwd
|
||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||
COPY netbird /go/bin/netbird
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -31,9 +31,9 @@ type IFaceDiscover interface {
|
||||
stdnet.ExternalIFaceDiscover
|
||||
}
|
||||
|
||||
// RouteListener export internal RouteListener for mobile
|
||||
type RouteListener interface {
|
||||
routemanager.RouteListener
|
||||
// NetworkChangeListener export internal NetworkChangeListener for mobile
|
||||
type NetworkChangeListener interface {
|
||||
listener.NetworkChangeListener
|
||||
}
|
||||
|
||||
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||
@@ -47,27 +47,26 @@ func init() {
|
||||
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
routeListener routemanager.RouteListener
|
||||
onHostDnsFn func([]string)
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client {
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
routeListener: routeListener,
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
networkChangeListener: networkChangeListener,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +79,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
return err
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
@@ -97,8 +97,32 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||
c.ctxCancelLock.Lock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
defer c.ctxCancel()
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -117,6 +141,11 @@ func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
}
|
||||
|
||||
// SetInfoLogLevel configure the logger to info level
|
||||
func (c *Client) SetInfoLogLevel() {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
}
|
||||
|
||||
// PeersList return with the list of the PeerInfos
|
||||
func (c *Client) PeersList() *PeerInfoArray {
|
||||
|
||||
|
||||
@@ -84,10 +84,14 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
@@ -189,7 +193,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -201,8 +205,8 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout*time.Second)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
defer cancel()
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
if err != nil {
|
||||
|
||||
@@ -57,11 +57,11 @@ func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
||||
p.SetManagementURL(exampleString)
|
||||
resp, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read managmenet url: %s", err)
|
||||
t.Fatalf("failed to read management url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleString {
|
||||
t.Errorf("unexpected managemenet url: %s", resp)
|
||||
t.Errorf("unexpected management url: %s", resp)
|
||||
}
|
||||
|
||||
p.SetPreSharedKey(exampleString)
|
||||
@@ -102,11 +102,11 @@ func TestPreferences_Commit(t *testing.T) {
|
||||
|
||||
resp, err = p.GetManagementURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read managmenet url: %s", err)
|
||||
t.Fatalf("failed to read management url: %s", err)
|
||||
}
|
||||
|
||||
if resp != exampleURL {
|
||||
t.Errorf("unexpected managemenet url: %s", resp)
|
||||
t.Errorf("unexpected management url: %s", resp)
|
||||
}
|
||||
|
||||
resp, err = p.GetPreSharedKey()
|
||||
|
||||
@@ -3,21 +3,20 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
@@ -52,7 +51,7 @@ var loginCmd = &cobra.Command{
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
@@ -61,7 +60,7 @@ var loginCmd = &cobra.Command{
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
@@ -82,9 +81,14 @@ var loginCmd = &cobra.Command{
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
SetupKey: setupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
@@ -114,7 +118,7 @@ var loginCmd = &cobra.Command{
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
@@ -150,13 +154,21 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if lastError != nil {
|
||||
return fmt.Errorf("login failed: %v", lastError)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
@@ -165,7 +177,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,8 +189,8 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn)
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout*time.Second)
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
defer c()
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
||||
@@ -191,17 +203,21 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
var codeMsg string
|
||||
if userCode != "" {
|
||||
if !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
err := open.Run(verificationURIComplete)
|
||||
cmd.Printf("Please do the SSO login in your browser. \n" +
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
" " + verificationURIComplete + " " + codeMsg + " \n\n")
|
||||
if err != nil {
|
||||
cmd.Printf("Alternatively, you may want to use a setup key, see:\n\n https://www.netbird.io/docs/overview/setup-keys\n")
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
cmd.Println("")
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
@@ -25,8 +25,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,6 +56,13 @@ var (
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
serviceName string
|
||||
autoConnectDisabled bool
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
@@ -87,14 +101,21 @@ func init() {
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
}
|
||||
|
||||
defaultServiceName := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultServiceName = "Netbird"
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
|
||||
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
@@ -118,6 +139,10 @@ func init() {
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -168,7 +193,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
return prefix + upper
|
||||
}
|
||||
|
||||
// DialClientGRPCServer returns client connection to the dameno server.
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
||||
defer cancel()
|
||||
|
||||
@@ -2,8 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -24,12 +22,8 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
name := "netbird"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "Netbird"
|
||||
}
|
||||
return &service.Config{
|
||||
Name: name,
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Option: make(service.KeyValue),
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func (p *program) Start(svc service.Service) error {
|
||||
@@ -109,7 +110,6 @@ var runCmd = &cobra.Command{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Printf("Netbird service is running")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ var installCmd = &cobra.Command{
|
||||
cmd.PrintErrln(err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Netbird service has been installed")
|
||||
return nil
|
||||
},
|
||||
@@ -106,7 +107,7 @@ var uninstallCmd = &cobra.Command{
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println("Netbird has been uninstalled")
|
||||
cmd.Println("Netbird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -22,14 +22,20 @@ import (
|
||||
)
|
||||
|
||||
type peerStateDetailOutput struct {
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
@@ -41,11 +47,25 @@ type peersStateOutput struct {
|
||||
type signalStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type managementStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Available int `json:"available" yaml:"available"`
|
||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type iceCandidateType struct {
|
||||
@@ -53,26 +73,40 @@ type iceCandidateType struct {
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type nsServerGroupStateOutput struct {
|
||||
Servers []string `json:"servers" yaml:"servers"`
|
||||
Domains []string `json:"domains" yaml:"domains"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
jsonFlag bool
|
||||
yamlFlag bool
|
||||
ipsFilter []string
|
||||
prefixNamesFilter []string
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -83,12 +117,14 @@ var statusCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
|
||||
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
|
||||
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
|
||||
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
@@ -109,9 +145,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
|
||||
resp, _ := getStatus(ctx, cmd)
|
||||
resp, err := getStatus(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
@@ -120,7 +156,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
" netbird up \n\n"+
|
||||
"If you are running a self-hosted version and no SSO provider has been configured in your Management Server,\n"+
|
||||
"you can use a setup-key:\n\n netbird up --management-url <YOUR_MANAGEMENT_URL> --setup-key <YOUR_SETUP_KEY>\n\n"+
|
||||
"More info: https://www.netbird.io/docs/overview/setup-keys\n\n",
|
||||
"More info: https://docs.netbird.io/how-to/register-machines-using-setup-keys\n\n",
|
||||
resp.GetStatus(),
|
||||
)
|
||||
return nil
|
||||
@@ -133,7 +169,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||
|
||||
statusOutputString := ""
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||
@@ -142,7 +178,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
case yamlFlag:
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false)
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -172,8 +208,12 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "disconnected", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
}
|
||||
@@ -185,11 +225,26 @@ func parseFilters() error {
|
||||
return fmt.Errorf("got an invalid IP address in the filter: address %s, error %s", addr, err)
|
||||
}
|
||||
ipsFilterMap[addr] = struct{}{}
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for _, name := range prefixNamesFilter {
|
||||
prefixNamesFilterMap[strings.ToLower(name)] = struct{}{}
|
||||
}
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableDetailFlagWhenFilterFlag() {
|
||||
if !detailFlag && !jsonFlag && !yamlFlag {
|
||||
detailFlag = true
|
||||
}
|
||||
}
|
||||
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
@@ -197,51 +252,108 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
||||
managementOverview := managementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
Connected: managementState.GetConnected(),
|
||||
Error: managementState.Error,
|
||||
}
|
||||
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
signalOverview := signalStateOutput{
|
||||
URL: signalState.GetURL(),
|
||||
Connected: signalState.GetConnected(),
|
||||
Error: signalState.Error,
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||
|
||||
overview := statusOutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
Relays: relayOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||
var relayStateDetail []relayStateOutputDetail
|
||||
|
||||
var relaysAvailable int
|
||||
for _, relay := range relays {
|
||||
available := relay.GetAvailable()
|
||||
relayStateDetail = append(relayStateDetail,
|
||||
relayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
},
|
||||
)
|
||||
|
||||
if available {
|
||||
relaysAvailable++
|
||||
}
|
||||
}
|
||||
|
||||
return relayStateOutput{
|
||||
Total: len(relays),
|
||||
Available: relaysAvailable,
|
||||
Details: relayStateDetail,
|
||||
}
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||
Servers: pbNsGroupServer.GetServers(),
|
||||
Domains: pbNsGroupServer.GetDomains(),
|
||||
Enabled: pbNsGroupServer.GetEnabled(),
|
||||
Error: pbNsGroupServer.GetError(),
|
||||
})
|
||||
}
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
connType := ""
|
||||
peersConnected := 0
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
for _, pbPeerState := range peers {
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
peersConnected = peersConnected + 1
|
||||
peersConnected++
|
||||
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
}
|
||||
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
@@ -256,7 +368,16 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
},
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetRoutes(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
@@ -306,22 +427,31 @@ func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
|
||||
managementConnString := "Disconnected"
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
signalConnString := "Disconnected"
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
@@ -333,6 +463,64 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(overview.Routes) > 0 {
|
||||
sort.Strings(overview.Routes)
|
||||
routes = strings.Join(overview.Routes, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
}
|
||||
errorString := ""
|
||||
if nsServerGroup.Error != "" {
|
||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||
errorString = strings.TrimSpace(errorString)
|
||||
}
|
||||
|
||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||
if domainsString == "" {
|
||||
domainsString = "." // Show "." for the default zone
|
||||
}
|
||||
dnsServersString += fmt.Sprintf(
|
||||
"\n [%s] for [%s] is %s%s",
|
||||
strings.Join(nsServerGroup.Servers, ", "),
|
||||
domainsString,
|
||||
enabled,
|
||||
errorString,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
@@ -340,25 +528,33 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool) string {
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"Relays: %s\n"+
|
||||
"Nameservers: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers)
|
||||
summary := parseGeneralSummary(overview, true)
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
summary := parseGeneralSummary(overview, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
@@ -369,7 +565,7 @@ func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
)
|
||||
}
|
||||
|
||||
func parsePeers(peers peersStateOutput) string {
|
||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||
var (
|
||||
peersString = ""
|
||||
)
|
||||
@@ -386,6 +582,48 @@ func parsePeers(peers peersStateOutput) string {
|
||||
remoteICE = peerState.IceCandidateType.Remote
|
||||
}
|
||||
|
||||
localICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Local != "" {
|
||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
||||
}
|
||||
|
||||
remoteICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
lastStatusUpdate := "-"
|
||||
if !peerState.LastStatusUpdate.IsZero() {
|
||||
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
lastWireGuardHandshake := "-"
|
||||
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
|
||||
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
} else {
|
||||
if rosenpassPermissive {
|
||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
||||
} else {
|
||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
||||
}
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(peerState.Routes) > 0 {
|
||||
sort.Strings(peerState.Routes)
|
||||
routes = strings.Join(peerState.Routes, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
@@ -395,7 +633,12 @@ func parsePeers(peers peersStateOutput) string {
|
||||
" Connection type: %s\n"+
|
||||
" Direct: %t\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" Last connection update: %s\n",
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
@@ -404,10 +647,17 @@ func parsePeers(peers peersStateOutput) string {
|
||||
peerState.Direct,
|
||||
localICE,
|
||||
remoteICE,
|
||||
peerState.LastStatusUpdate.Format("2006-01-02 15:04:05"),
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
lastStatusUpdate,
|
||||
lastWireGuardHandshake,
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
)
|
||||
|
||||
peersString = peersString + peerString
|
||||
peersString += peerString
|
||||
}
|
||||
return peersString
|
||||
}
|
||||
@@ -415,6 +665,7 @@ func parsePeers(peers peersStateOutput) string {
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -431,5 +682,39 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
ipEval = true
|
||||
}
|
||||
}
|
||||
return statusEval || ipEval
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
func toIEC(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
count := 0
|
||||
for _, server := range dnsServers {
|
||||
if server.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -25,41 +28,93 @@ var resp = &proto.StatusResponse{
|
||||
FullStatus: &proto.FullStatus{
|
||||
Peers: []*proto.PeerState{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
RemoteIceCandidateEndpoint: "",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||
BytesRx: 2000,
|
||||
BytesTx: 1000,
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: &proto.SignalState{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: []*proto.RelayState{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
LocalPeerState: &proto.LocalPeerState{
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
DnsServers: []*proto.NSGroupState{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
@@ -82,6 +137,16 @@ var overview = statusOutputOverview{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||
TransferReceived: 200,
|
||||
TransferSent: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
@@ -95,6 +160,13 @@ var overview = statusOutputOverview{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "10.0.0.1:10001",
|
||||
Remote: "10.0.10.1:10002",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||
TransferReceived: 2000,
|
||||
TransferSent: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -103,15 +175,58 @@ var overview = statusOutputOverview{
|
||||
ManagementState: managementStateOutput{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: signalStateOutput{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: relayStateOutput{
|
||||
Total: 2,
|
||||
Available: 1,
|
||||
Details: []relayStateOutputDetail{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []nsServerGroupStateOutput{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -145,107 +260,219 @@ func TestSortingOfPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
json, _ := parseToJSON(overview)
|
||||
jsonString, _ := parseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSON := "{\"" +
|
||||
"peers\":" +
|
||||
"{" +
|
||||
"\"total\":2," +
|
||||
"\"connected\":2," +
|
||||
"\"details\":" +
|
||||
"[" +
|
||||
"{" +
|
||||
"\"fqdn\":\"peer-1.awesome-domain.com\"," +
|
||||
"\"netbirdIp\":\"192.168.178.101\"," +
|
||||
"\"publicKey\":\"Pubkey1\"," +
|
||||
"\"status\":\"Connected\"," +
|
||||
"\"lastStatusUpdate\":\"2001-01-01T01:01:01Z\"," +
|
||||
"\"connectionType\":\"P2P\"," +
|
||||
"\"direct\":true," +
|
||||
"\"iceCandidateType\":" +
|
||||
"{" +
|
||||
"\"local\":\"\"," +
|
||||
"\"remote\":\"\"" +
|
||||
"}" +
|
||||
"}," +
|
||||
"{" +
|
||||
"\"fqdn\":\"peer-2.awesome-domain.com\"," +
|
||||
"\"netbirdIp\":\"192.168.178.102\"," +
|
||||
"\"publicKey\":\"Pubkey2\"," +
|
||||
"\"status\":\"Connected\"," +
|
||||
"\"lastStatusUpdate\":\"2002-02-02T02:02:02Z\"," +
|
||||
"\"connectionType\":\"Relayed\"," +
|
||||
"\"direct\":false," +
|
||||
"\"iceCandidateType\":" +
|
||||
"{" +
|
||||
"\"local\":\"relay\"," +
|
||||
"\"remote\":\"prflx\"" +
|
||||
"}" +
|
||||
"}" +
|
||||
"]" +
|
||||
"}," +
|
||||
"\"cliVersion\":\"development\"," +
|
||||
"\"daemonVersion\":\"0.14.1\"," +
|
||||
"\"management\":" +
|
||||
"{" +
|
||||
"\"url\":\"my-awesome-management.com:443\"," +
|
||||
"\"connected\":true" +
|
||||
"}," +
|
||||
"\"signal\":" +
|
||||
"{\"" +
|
||||
"url\":\"my-awesome-signal.com:443\"," +
|
||||
"\"connected\":true" +
|
||||
"}," +
|
||||
"\"netbirdIp\":\"192.168.178.100/16\"," +
|
||||
"\"publicKey\":\"Some-Pub-Key\"," +
|
||||
"\"usesKernelInterface\":true," +
|
||||
"\"fqdn\":\"some-localhost.awesome-domain.com\"" +
|
||||
"}"
|
||||
expectedJSONString := `
|
||||
{
|
||||
"peers": {
|
||||
"total": 2,
|
||||
"connected": 2,
|
||||
"details": [
|
||||
{
|
||||
"fqdn": "peer-1.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.101",
|
||||
"publicKey": "Pubkey1",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"direct": true,
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"fqdn": "peer-2.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.102",
|
||||
"publicKey": "Pubkey2",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"direct": false,
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"quantumResistance": false,
|
||||
"routes": null
|
||||
}
|
||||
]
|
||||
},
|
||||
"cliVersion": "development",
|
||||
"daemonVersion": "0.14.1",
|
||||
"management": {
|
||||
"url": "my-awesome-management.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"signal": {
|
||||
"url": "my-awesome-signal.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"relays": {
|
||||
"total": 2,
|
||||
"available": 1,
|
||||
"details": [
|
||||
{
|
||||
"uri": "stun:my-awesome-stun.com:3478",
|
||||
"available": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
"available": false,
|
||||
"error": "context: deadline exceeded"
|
||||
}
|
||||
]
|
||||
},
|
||||
"netbirdIp": "192.168.178.100/16",
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
"8.8.8.8:53"
|
||||
],
|
||||
"domains": null,
|
||||
"enabled": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"servers": [
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53"
|
||||
],
|
||||
"domains": [
|
||||
"example.com",
|
||||
"example.net"
|
||||
],
|
||||
"enabled": false,
|
||||
"error": "timeout"
|
||||
}
|
||||
]
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
assert.Equal(t, expectedJSON, json)
|
||||
var expectedJSON bytes.Buffer
|
||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
||||
|
||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := parseToYAML(overview)
|
||||
|
||||
expectedYAML := "peers:\n" +
|
||||
" total: 2\n" +
|
||||
" connected: 2\n" +
|
||||
" details:\n" +
|
||||
" - fqdn: peer-1.awesome-domain.com\n" +
|
||||
" netbirdIp: 192.168.178.101\n" +
|
||||
" publicKey: Pubkey1\n" +
|
||||
" status: Connected\n" +
|
||||
" lastStatusUpdate: 2001-01-01T01:01:01Z\n" +
|
||||
" connectionType: P2P\n" +
|
||||
" direct: true\n" +
|
||||
" iceCandidateType:\n" +
|
||||
" local: \"\"\n" +
|
||||
" remote: \"\"\n" +
|
||||
" - fqdn: peer-2.awesome-domain.com\n" +
|
||||
" netbirdIp: 192.168.178.102\n" +
|
||||
" publicKey: Pubkey2\n" +
|
||||
" status: Connected\n" +
|
||||
" lastStatusUpdate: 2002-02-02T02:02:02Z\n" +
|
||||
" connectionType: Relayed\n" +
|
||||
" direct: false\n" +
|
||||
" iceCandidateType:\n" +
|
||||
" local: relay\n" +
|
||||
" remote: prflx\n" +
|
||||
"cliVersion: development\n" +
|
||||
"daemonVersion: 0.14.1\n" +
|
||||
"management:\n" +
|
||||
" url: my-awesome-management.com:443\n" +
|
||||
" connected: true\n" +
|
||||
"signal:\n" +
|
||||
" url: my-awesome-signal.com:443\n" +
|
||||
" connected: true\n" +
|
||||
"netbirdIp: 192.168.178.100/16\n" +
|
||||
"publicKey: Some-Pub-Key\n" +
|
||||
"usesKernelInterface: true\n" +
|
||||
"fqdn: some-localhost.awesome-domain.com\n"
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
total: 2
|
||||
connected: 2
|
||||
details:
|
||||
- fqdn: peer-1.awesome-domain.com
|
||||
netbirdIp: 192.168.178.101
|
||||
publicKey: Pubkey1
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
direct: true
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
direct: false
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
url: my-awesome-management.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
signal:
|
||||
url: my-awesome-signal.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
relays:
|
||||
total: 2
|
||||
available: 1
|
||||
details:
|
||||
- uri: stun:my-awesome-stun.com:3478
|
||||
available: true
|
||||
error: ""
|
||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
||||
available: false
|
||||
error: 'context: deadline exceeded'
|
||||
netbirdIp: 192.168.178.100/16
|
||||
publicKey: Some-Pub-Key
|
||||
usesKernelInterface: true
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
domains: []
|
||||
enabled: true
|
||||
error: ""
|
||||
- servers:
|
||||
- 1.1.1.1:53
|
||||
- 2.2.2.2:53
|
||||
domains:
|
||||
- example.com
|
||||
- example.net
|
||||
enabled: false
|
||||
error: timeout
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
@@ -253,50 +480,76 @@ func TestParsingToYAML(t *testing.T) {
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := "Peers detail:\n" +
|
||||
" peer-1.awesome-domain.com:\n" +
|
||||
" NetBird IP: 192.168.178.101\n" +
|
||||
" Public key: Pubkey1\n" +
|
||||
" Status: Connected\n" +
|
||||
" -- detail --\n" +
|
||||
" Connection type: P2P\n" +
|
||||
" Direct: true\n" +
|
||||
" ICE candidate (Local/Remote): -/-\n" +
|
||||
" Last connection update: 2001-01-01 01:01:01\n" +
|
||||
"\n" +
|
||||
" peer-2.awesome-domain.com:\n" +
|
||||
" NetBird IP: 192.168.178.102\n" +
|
||||
" Public key: Pubkey2\n" +
|
||||
" Status: Connected\n" +
|
||||
" -- detail --\n" +
|
||||
" Connection type: Relayed\n" +
|
||||
" Direct: false\n" +
|
||||
" ICE candidate (Local/Remote): relay/prflx\n" +
|
||||
" Last connection update: 2002-02-02 02:02:02\n" +
|
||||
"\n" +
|
||||
"Daemon version: 0.14.1\n" +
|
||||
"CLI version: development\n" +
|
||||
"Management: Connected to my-awesome-management.com:443\n" +
|
||||
"Signal: Connected to my-awesome-signal.com:443\n" +
|
||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||
"NetBird IP: 192.168.178.100/16\n" +
|
||||
"Interface type: Kernel\n" +
|
||||
"Peers count: 2/2 Connected\n"
|
||||
expectedDetail :=
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
Public key: Pubkey1
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
Direct: true
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Last connection update: 2001-01-01 01:01:01
|
||||
Last WireGuard handshake: 2001-01-01 01:01:02
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.102
|
||||
Public key: Pubkey2
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
Direct: false
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Last connection update: 2002-02-02 02:02:02
|
||||
Last WireGuard handshake: 2002-02-02 02:02:03
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
[stun:my-awesome-stun.com:3478] is Available
|
||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||
Nameservers:
|
||||
[8.8.8.8:53] for [.] is Available
|
||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false)
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString := "Daemon version: 0.14.1\n" +
|
||||
"CLI version: development\n" +
|
||||
"Management: Connected\n" +
|
||||
"Signal: Connected\n" +
|
||||
"FQDN: some-localhost.awesome-domain.com\n" +
|
||||
"NetBird IP: 192.168.178.100/16\n" +
|
||||
"Interface type: Kernel\n" +
|
||||
"Peers count: 2/2 Connected\n"
|
||||
expectedString :=
|
||||
`Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
Relays: 1/2 Available
|
||||
Nameservers: 1/2 Available
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedString, shortVersion)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
)
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &mgmt.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
@@ -44,6 +45,7 @@ func startTestingServices(t *testing.T) string {
|
||||
}
|
||||
|
||||
func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -60,28 +62,28 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, err := mgmt.NewFileStore(config.Datadir, nil)
|
||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -98,6 +100,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
func startClientDaemon(
|
||||
t *testing.T, ctx context.Context, managementURL, configPath string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
114
client/cmd/up.go
114
client/cmd/up.go
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -36,6 +38,8 @@ var (
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -85,16 +89,53 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
}
|
||||
if preSharedKey != "" {
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementPort(ctx, config, configPath)
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
@@ -123,7 +164,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing dameon gRPC client connection %v", err)
|
||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -141,13 +182,46 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
SetupKey: setupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
loginRequest.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
loginRequest.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
loginRequest.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
wp := int64(wireguardPort)
|
||||
loginRequest.WireguardPort = &wp
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
@@ -178,7 +252,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
@@ -199,11 +273,11 @@ func validateNATExternalIPs(list []string) error {
|
||||
|
||||
subElements := strings.Split(element, "/")
|
||||
if len(subElements) > 2 {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"String\" or \"String/String\"", element, externalIPMapFlag)
|
||||
}
|
||||
|
||||
if len(subElements) == 1 && !isValidIP(subElements[0]) {
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formated as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
||||
return fmt.Errorf("%s is not a valid input for %s. it should be formatted as \"IP\" or \"IP/IP\", or \"IP/Interface Name\"", element, externalIPMapFlag)
|
||||
}
|
||||
|
||||
last := 0
|
||||
@@ -221,6 +295,18 @@ func validateNATExternalIPs(list []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseInterfaceName(name string) error {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "utun") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid interface name %s. Please use the prefix utun followed by a number on MacOS. e.g., utun1 or utun199", name)
|
||||
}
|
||||
|
||||
func validateElement(element string) (int, error) {
|
||||
if isValidIP(element) {
|
||||
return ipInputType, nil
|
||||
@@ -258,7 +344,7 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
var parsed []byte
|
||||
if modified {
|
||||
if !isValidAddrPort(customDNSAddress) {
|
||||
return nil, fmt.Errorf("%s is invalid, it should be formated as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||
return nil, fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", customDNSAddress)
|
||||
}
|
||||
if customDNSAddress == "" && logFile != "console" {
|
||||
parsed = []byte("empty")
|
||||
|
||||
32
client/firewall/create.go
Normal file
32
client/firewall/create.go
Normal file
@@ -0,0 +1,32 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = fm.AllowNetbird()
|
||||
if err != nil {
|
||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
107
client/firewall/create_linux.go
Normal file
107
client/firewall/create_linux.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build !android
|
||||
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN FWType = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
var errFw error
|
||||
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Debug("creating an iptables firewall manager")
|
||||
fm, errFw = nbiptables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||
}
|
||||
case NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager")
|
||||
fm, errFw = nbnftables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||
}
|
||||
default:
|
||||
errFw = fmt.Errorf("no firewall manager found")
|
||||
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
var errUsp error
|
||||
if errFw == nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
if errUsp != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||
return nil, errUsp
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
if errFw != nil {
|
||||
return nil, errFw
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return UNKNOWN
|
||||
}
|
||||
if isIptablesClientAvailable(ip) {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
11
client/firewall/iface.go
Normal file
11
client/firewall/iface.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package firewall
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
473
client/firewall/iptables/acl_linux.go
Normal file
473
client/firewall/iptables/acl_linux.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
|
||||
postRoutingMark = "0x000007e4"
|
||||
)
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
||||
}
|
||||
|
||||
m.seedInitialEntries()
|
||||
|
||||
err = m.cleanChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = m.createDefaultChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
|
||||
var chain string
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
chain = chainNameOutputRules
|
||||
} else {
|
||||
chain = chainNameInputRules
|
||||
}
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
ipList.addIP(ip.String())
|
||||
return []firewall.Rule{&Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
||||
if err != nil {
|
||||
return []firewall.Rule{rule}, err
|
||||
}
|
||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.chain == "PREROUTING" {
|
||||
goto DELETERULE
|
||||
}
|
||||
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ipsetList.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
DELETERULE:
|
||||
var table string
|
||||
if r.chain == "PREROUTING" {
|
||||
table = "mangle"
|
||||
} else {
|
||||
table = "filter"
|
||||
}
|
||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
||||
if err != nil {
|
||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
return m.cleanChains()
|
||||
}
|
||||
|
||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
||||
var src []string
|
||||
if ipsetName != "" {
|
||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
||||
} else {
|
||||
src = []string{"-s", ip.String()}
|
||||
}
|
||||
specs := []string{
|
||||
"-d", m.wgIface.Address().IP.String(),
|
||||
"-p", protocol,
|
||||
"--dport", port,
|
||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
||||
}
|
||||
|
||||
specs = append(src, specs...)
|
||||
|
||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: "PREROUTING",
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
rules := m.entries["OUTPUT"]
|
||||
for _, rule := range rules {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["INPUT"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries["FORWARD"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) createDefaultChains() error {
|
||||
// chain netbird-acl-input-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// chain netbird-acl-output-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if chainName == "FORWARD" {
|
||||
// position 2 because we add it after router's, jump rule
|
||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
||||
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
|
||||
m.appendToEntries("PREROUTING",
|
||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case firewall.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case firewall.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
return append(specs, "-j", actionToStr(action))
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
switch {
|
||||
case ipsetName == "":
|
||||
return ""
|
||||
case sPort != "" && dPort != "":
|
||||
return ipsetName + "-sport-dport"
|
||||
case sPort != "":
|
||||
return ipsetName + "-sport"
|
||||
case dPort != "":
|
||||
return ipsetName + "-dport"
|
||||
default:
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,262 +1,105 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
|
||||
ChainInputFilterName = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
|
||||
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
|
||||
)
|
||||
|
||||
// dropAllDefaultRule in the Netbird chain
|
||||
var dropAllDefaultRule = []string{"-j", "DROP"}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
ipv6Client *iptables.IPTables
|
||||
|
||||
inputDefaultRuleSpecs []string
|
||||
outputDefaultRuleSpecs []string
|
||||
wgIface iFaceMapper
|
||||
|
||||
rulesets map[string]ruleset
|
||||
aclMgr *aclManager
|
||||
router *routerManager
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
type ruleset struct {
|
||||
rule *Rule
|
||||
ips map[string]string
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
wgIface: wgIface,
|
||||
inputDefaultRuleSpecs: []string{
|
||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
||||
outputDefaultRuleSpecs: []string{
|
||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
||||
rulesets: make(map[string]ruleset),
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
if isIptablesClientAvailable(ipv4Client) {
|
||||
m.ipv4Client = ipv4Client
|
||||
|
||||
m := &Manager{
|
||||
wgIface: wgIface,
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
m.router, err = newRouterManager(context, iptablesClient)
|
||||
if err != nil {
|
||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||
} else {
|
||||
if isIptablesClientAvailable(ipv6Client) {
|
||||
m.ipv6Client = ipv6Client
|
||||
}
|
||||
log.Debugf("failed to initialize route related chains: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
client, err := m.client(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
}
|
||||
|
||||
if ipsetName != "" {
|
||||
rs, rsExists := m.rulesets[ipsetName]
|
||||
if !rsExists {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
if rsExists {
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
rs.ips[ip.String()] = ruleID
|
||||
return &Rule{
|
||||
ruleID: ruleID,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
}
|
||||
// this is new ipset so we need to create firewall rule for it
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal,
|
||||
direction, action, comment, ipsetName)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is output rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("input rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is input rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("input rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: ruleID,
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
if ipsetName != "" {
|
||||
// ipset name is defined and it means that this rule was created
|
||||
// for it, need to assosiate it with ruleset
|
||||
m.rulesets[ipsetName] = ruleset{
|
||||
rule: rule,
|
||||
ips: map[string]string{rule.ip: ruleID},
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
return m.aclMgr.DeleteRule(rule)
|
||||
}
|
||||
|
||||
client := m.ipv4Client
|
||||
if r.v6 {
|
||||
if m.ipv6Client == nil {
|
||||
return fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
client = m.ipv6Client
|
||||
}
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := rs.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(rs.ips, r.ip)
|
||||
}
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(rs.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and assosiated firewall rule too
|
||||
delete(m.rulesets, r.ipsetName)
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
r = rs.rule
|
||||
}
|
||||
|
||||
if r.dst {
|
||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||
}
|
||||
return client.Delete("filter", ChainInputFilterName, r.specs...)
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -264,192 +107,49 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.reset(m.ipv4Client, "filter"); err != nil {
|
||||
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
|
||||
errAcl := m.aclMgr.Reset()
|
||||
if errAcl != nil {
|
||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
||||
}
|
||||
if m.ipv6Client != nil {
|
||||
if err := m.reset(m.ipv6Client, "filter"); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
|
||||
}
|
||||
errMgr := m.router.Reset()
|
||||
if errMgr != nil {
|
||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
||||
return errMgr
|
||||
}
|
||||
return errAcl
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionIN,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionOUT,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if input chain exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = client.ChainExists(table, ChainOutputFilterName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if output chain exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
||||
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
|
||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
|
||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for ipsetName := range m.rulesets {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
delete(m.rulesets, ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||
direction fw.RuleDirection, action fw.Action, comment string,
|
||||
ipsetName string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case fw.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case fw.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
specs = append(specs, "-j", m.actionToStr(action))
|
||||
return append(specs, "-m", "comment", "--comment", comment)
|
||||
}
|
||||
|
||||
// rawClient returns corresponding iptables client for the given ip
|
||||
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
|
||||
if ip.To4() != nil {
|
||||
return m.ipv4Client, nil
|
||||
}
|
||||
if m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
return m.ipv6Client, nil
|
||||
}
|
||||
|
||||
// client returns client with initialized chain and default rules
|
||||
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
||||
client, err := m.rawClient(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := client.ChainExists("filter", ChainInputFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
ok, err = client.ChainExists("filter", ChainOutputFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (m *Manager) actionToStr(action fw.Action) string {
|
||||
if action == fw.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
if ipsetName == "" {
|
||||
return ""
|
||||
} else if sPort != "" && dPort != "" {
|
||||
return ipsetName + "-sport-dport"
|
||||
} else if sPort != "" {
|
||||
return ipsetName + "-sport"
|
||||
} else if dPort != "" {
|
||||
return ipsetName + "-dport"
|
||||
}
|
||||
return ipsetName
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
@@ -33,6 +34,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
@@ -53,7 +56,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -65,17 +68,20 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
@@ -85,21 +91,28 @@ func TestIptablesManager(t *testing.T) {
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||
for _, r := range rule2 {
|
||||
rr := r.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -112,11 +125,11 @@ func TestIptablesManager(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -141,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -153,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
@@ -163,12 +176,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
@@ -178,23 +193,29 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule1)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
err := manager.DeleteRule(rule2)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -204,6 +225,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||
t.Helper()
|
||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||
require.NoError(t, err, "failed to check rule")
|
||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
||||
@@ -229,7 +251,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -240,7 +262,6 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
_, err = manager.client(net.ParseIP("10.20.0.100"))
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
|
||||
340
client/firewall/iptables/router_linux.go
Normal file
340
client/firewall/iptables/router_linux.go
Normal file
@@ -0,0 +1,340 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
)
|
||||
|
||||
type routerManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
}
|
||||
|
||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
m := &routerManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
}
|
||||
|
||||
err := m.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
err = m.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return m, err
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts an iptable rule
|
||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) RouteingFwChainName() string {
|
||||
return chainRTFWD
|
||||
}
|
||||
|
||||
func (i *routerManager) Reset() error {
|
||||
err := i.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules = make(map[string][]string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createContainers() error {
|
||||
if i.rules[Ipv4Forwarding] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errMSGFormat := "failed creating chain %s,error: %v"
|
||||
err := i.createChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
||||
}
|
||||
|
||||
err = i.createChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addJumpRules create jump rules to send packets to NetBird chains
|
||||
func (i *routerManager) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTFWD}
|
||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[Ipv4Forwarding] = rule
|
||||
|
||||
rule = []string{"-j", chainRTNAT}
|
||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||
func (i *routerManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
||||
rule, found := i.rules[Ipv4Forwarding]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4Nat]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createChain(table, newChain string) error {
|
||||
chains, err := i.iptablesClient.ListChains(table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
||||
}
|
||||
|
||||
shouldCreateChain := true
|
||||
for _, chain := range chains {
|
||||
if chain == newChain {
|
||||
shouldCreateChain = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreateChain {
|
||||
err = i.iptablesClient.NewChain(table, newChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification with comment identifier
|
||||
func genRuleSpec(jump, id, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
||||
}
|
||||
|
||||
func getIptablesRuleType(table string) string {
|
||||
ruleType := "forwarding"
|
||||
if table == tableNat {
|
||||
ruleType = "nat"
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
229
client/firewall/iptables/router_linux_test.go
Normal file
229
client/firewall/iptables/router_linux_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
_, err4 := exec.LookPath("iptables")
|
||||
return err4 == nil
|
||||
}
|
||||
|
||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.100.0/24",
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset iptables manager: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
foundRule, found := manager.rules[forwardRuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "income forwarding rule should exist")
|
||||
|
||||
foundRule, found = manager.rules[inForwardRuleKey]
|
||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
||||
require.True(t, foundNat, "nat rule should exist in the map")
|
||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[natRuleKey]
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "income nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[inNatRuleKey]
|
||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "forwarding rule should not exist")
|
||||
|
||||
_, found := manager.rules[forwardRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "income forwarding rule should not exist")
|
||||
|
||||
_, found = manager.rules[inForwardRuleKey]
|
||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[natRuleKey]
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "income nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,7 @@ type Rule struct {
|
||||
|
||||
specs []string
|
||||
ip string
|
||||
dst bool
|
||||
v6 bool
|
||||
chain string
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
50
client/firewall/iptables/rulestore_linux.go
Normal file
50
client/firewall/iptables/rulestore_linux.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package iptables
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]ipList // ipsetName -> ruleset
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
s.ipsets[ipsetName] = ipList{}
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipsetNames() []string {
|
||||
names := make([]string, 0, len(s.ipsets))
|
||||
for name := range s.ipsets {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -1,9 +1,17 @@
|
||||
package firewall
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
NatFormat = "netbird-nat-%s"
|
||||
ForwardingFormat = "netbird-fwd-%s"
|
||||
InNatFormat = "netbird-nat-in-%s"
|
||||
InForwardingFormat = "netbird-fwd-in-%s"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
@@ -27,10 +35,8 @@ const (
|
||||
type Action int
|
||||
|
||||
const (
|
||||
// ActionUnknown is a unknown action
|
||||
ActionUnknown Action = iota
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept
|
||||
ActionAccept Action = iota
|
||||
// ActionDrop is the action to drop a packet
|
||||
ActionDrop
|
||||
)
|
||||
@@ -40,6 +46,9 @@ const (
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
@@ -53,16 +62,27 @@ type Manager interface {
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
DeleteRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
// InsertRoutingRules inserts a routing firewall rule
|
||||
InsertRoutingRules(pair RouterPair) error
|
||||
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair RouterPair) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// TODO: migrate routemanager firewal actions to this interface
|
||||
}
|
||||
|
||||
func GenKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package firewall
|
||||
package manager
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
18
client/firewall/manager/routerpair.go
Normal file
18
client/firewall/manager/routerpair.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package manager
|
||||
|
||||
type RouterPair struct {
|
||||
ID string
|
||||
Source string
|
||||
Destination string
|
||||
Masquerade bool
|
||||
}
|
||||
|
||||
func GetInPair(pair RouterPair) RouterPair {
|
||||
return RouterPair{
|
||||
ID: pair.ID,
|
||||
// invert Source/Destination
|
||||
Source: pair.Destination,
|
||||
Destination: pair.Source,
|
||||
Masquerade: pair.Masquerade,
|
||||
}
|
||||
}
|
||||
1196
client/firewall/nftables/acl_linux.go
Normal file
1196
client/firewall/nftables/acl_linux.go
Normal file
File diff suppressed because it is too large
Load Diff
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsetReference map[string]int
|
||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsetReference: make(map[string]int),
|
||||
ipsets: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||
s.ipsetReference[ipsetName] = 0
|
||||
ipList := make(map[string]struct{})
|
||||
s.ipsets[ipsetName] = ipList
|
||||
return ipList
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsetReference, ipsetName)
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ipList, ip.String())
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ipList[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = ipList[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||
s.ipsetReference[ipsetName]++
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||
r, ok := s.ipsetReference[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if r == 0 {
|
||||
return
|
||||
}
|
||||
s.ipsetReference[ipsetName]--
|
||||
}
|
||||
|
||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||
return false
|
||||
}
|
||||
if s.ipsetReference[ipsetName] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -2,88 +2,52 @@ package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
// FilterTableName is the name of the table that is used for filtering by the Netbird client
|
||||
FilterTableName = "netbird-acl"
|
||||
|
||||
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
|
||||
FilterInputChainName = "netbird-acl-input-filter"
|
||||
|
||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
||||
tableName = "netbird"
|
||||
)
|
||||
|
||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
|
||||
filterInputChainIPv4 *nftables.Chain
|
||||
filterOutputChainIPv4 *nftables.Chain
|
||||
|
||||
filterInputChainIPv6 *nftables.Chain
|
||||
filterOutputChainIPv6 *nftables.Chain
|
||||
|
||||
rulesetManager *rulesetManager
|
||||
setRemovedIPs map[string]struct{}
|
||||
setRemoved map[string]*nftables.Set
|
||||
|
||||
mutex sync.Mutex
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
|
||||
rulesetManager: newRuleManager(),
|
||||
setRemovedIPs: map[string]struct{}{},
|
||||
setRemoved: map[string]*nftables.Set{},
|
||||
|
||||
wgIface: wgIface,
|
||||
m.router, err = newRouter(context, workTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -96,533 +60,98 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var (
|
||||
err error
|
||||
ipset *nftables.Set
|
||||
table *nftables.Table
|
||||
chain *nftables.Chain
|
||||
)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
FilterOutputChainName,
|
||||
nftables.ChainHookOutput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
} else {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
FilterInputChainName,
|
||||
nftables.ChainHookInput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
rawIP = ip.To16()
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||
|
||||
if ipsetName != "" {
|
||||
// if we already have set with given name, just add ip to the set
|
||||
// and return rule with new ID in other case let's create rule
|
||||
// with fresh created set and set element
|
||||
|
||||
var isSetNew bool
|
||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
isSetNew = true
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
if !isSetNew {
|
||||
// if we already have nftables rules with set for given direction
|
||||
// just add new rule to the ruleset and return new fw.Rule object
|
||||
|
||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
// if ipset exists but it is not linked to rule for given direction
|
||||
// create new rule for direction and bind ipset to it later
|
||||
}
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
if proto != "all" {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case fw.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case fw.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: protoData,
|
||||
})
|
||||
}
|
||||
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
addrLen := uint32(len(rawIP))
|
||||
addrOffset := uint32(12)
|
||||
if addrLen == 16 {
|
||||
addrOffset = 8
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
addrOffset += addrLen
|
||||
}
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: addrOffset,
|
||||
Len: addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipsetName,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 0,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*sPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*dPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if action == fw.ActionAccept {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
} else {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||
|
||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
|
||||
// getRulesetID returns ruleset ID based on given parameters
|
||||
func (m *Manager) getRulesetID(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipsetName == "" {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipsetName + rulesetID
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *Manager) createSet(
|
||||
table *nftables.Table,
|
||||
rawIP []byte,
|
||||
name string,
|
||||
) (*nftables.Set, error) {
|
||||
keyType := nftables.TypeIPAddr
|
||||
if len(rawIP) == 16 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
// else we create new ipset and continue creating rule
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: keyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// chain returns the chain for the given IP address with specific settings
|
||||
func (m *Manager) chain(
|
||||
ip net.IP,
|
||||
name string,
|
||||
hook nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
cType nftables.ChainType,
|
||||
) (*nftables.Table, *nftables.Chain, error) {
|
||||
var err error
|
||||
|
||||
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if name == FilterInputChainName {
|
||||
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
|
||||
return m.tableIPv4, m.filterInputChainIPv4, err
|
||||
}
|
||||
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
|
||||
return m.tableIPv4, m.filterOutputChainIPv4, err
|
||||
}
|
||||
if name == FilterInputChainName {
|
||||
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
|
||||
return m.tableIPv4, m.filterInputChainIPv6, err
|
||||
}
|
||||
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
|
||||
return m.tableIPv4, m.filterOutputChainIPv6, err
|
||||
}
|
||||
|
||||
// table returns the table for the given family of the IP address
|
||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
if m.tableIPv4 != nil {
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.tableIPv4 = table
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
if m.tableIPv6 != nil {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.tableIPv6 = table
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
family nftables.TableFamily,
|
||||
name string,
|
||||
hooknum nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
chainType nftables.ChainType,
|
||||
) (*nftables.Chain, error) {
|
||||
table, err := m.table(family)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
for _, c := range chains {
|
||||
if c.Name == name && c.Table.Name == table.Name {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Hooknum: hooknum,
|
||||
Priority: priority,
|
||||
Type: chainType,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
shiftDSTAddr := 0
|
||||
if name == FilterOutputChainName {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
shiftDSTAddr = 1
|
||||
}
|
||||
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
|
||||
if m.wgIface.Address().IP.To4() == nil {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(8 + (16 * shiftDSTAddr)),
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 16,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: mask.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
} else {
|
||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(12 + (4 * shiftDSTAddr)),
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 2,
|
||||
DestRegister: 2,
|
||||
Len: 4,
|
||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
||||
Mask: m.wgIface.Address().Network.Mask,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 2,
|
||||
Data: ip.Unmap().AsSlice(),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
expressions = []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
nativeRule, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
return m.aclManager.DeleteRule(rule)
|
||||
}
|
||||
|
||||
if nativeRule.nftRule == nil {
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
// call twice of delete set element raises error
|
||||
// so we need to check if element is already removed
|
||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.setRemovedIPs[key] = struct{}{}
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
err := m.aclManager.createDefaultAllowRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if m.rulesetManager.deleteRule(nativeRule) {
|
||||
// deleteRule indicates that we still have IP in the ruleset
|
||||
// it means we should not remove the nftables rule but need to update set
|
||||
// so we prepare IP to be removed from set on the next flush call
|
||||
if chain == nil {
|
||||
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
nativeRule.nftRule = nil
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||
}
|
||||
nativeRule.nftSet = nil
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
m.applyAllowNetbirdRules(chain)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -637,18 +166,33 @@ func (m *Manager) Reset() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
for _, c := range chains {
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.rConn.DelChain(c)
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.router.ResetForwardRules()
|
||||
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
@@ -659,100 +203,65 @@ func (m *Manager) Reset() error {
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
// todo review this method usage
|
||||
func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set must be removed after flush rule changes
|
||||
// otherwise we will get error
|
||||
for _, s := range m.setRemoved {
|
||||
m.rConn.FlushSet(s)
|
||||
m.rConn.DelSet(s)
|
||||
}
|
||||
|
||||
if len(m.setRemoved) > 0 {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.setRemovedIPs = map[string]struct{}{}
|
||||
m.setRemoved = map[string]*nftables.Set{}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
func (m *Manager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime = backoffTime * 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||
if table == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(table, chain)
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) != 0 {
|
||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||
log.Errorf("failed to set rule handle: %v", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
}
|
||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||
continue
|
||||
}
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
return bs
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, []byte(n+"\x00"))
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
@@ -36,6 +37,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
@@ -53,7 +56,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
@@ -82,14 +85,10 @@ func TestNftablesManager(t *testing.T) {
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
// test expectations:
|
||||
// 1) regular rule
|
||||
// 2) "accept extra routed traffic rule" for the interface
|
||||
// 3) "drop all rule" for the interface
|
||||
require.Len(t, rules, 3, "expected 3 rules")
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
@@ -137,18 +136,17 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
err = manager.DeleteRule(rule)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
for _, r := range rule {
|
||||
err = manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// test expectations:
|
||||
// 1) "accept extra routed traffic rule" for the interface
|
||||
// 2) "drop all rule" for the interface
|
||||
require.Len(t, rules, 2, "expected 2 rules after deleteion")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
@@ -173,7 +171,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(context.Background(), mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
|
||||
413
client/firewall/nftables/route_linux.go
Normal file
413
client/firewall/nftables/route_linux.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRouteingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (r *router) RouteingFwChainName() string {
|
||||
return chainNameRouteingFw
|
||||
}
|
||||
|
||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||
func (r *router) ResetForwardRules() {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRouteingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
log.Debugf("add default accept forward rule")
|
||||
r.acceptForwardRule(pair.Source)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
var expression []expr.Any
|
||||
if isNat {
|
||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
||||
} else {
|
||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
||||
}
|
||||
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
_, exists := r.rules[ruleKey]
|
||||
if exists {
|
||||
err := r.removeRoutingRule(format, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||
}
|
||||
|
||||
r.conn.AddRule(rule)
|
||||
|
||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule = &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
r.conn.AddRule(rule)
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.rules) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
rule, found := r.rules[ruleKey]
|
||||
if found {
|
||||
ruleType := "forwarding"
|
||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||
ruleType = "nat"
|
||||
}
|
||||
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rules []*nftables.Rule
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
if chain.Name != "FORWARD" {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
ip, network, _ := net.ParseCIDR(cidr)
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
var offSet uint32
|
||||
if source {
|
||||
offSet = 12 // src offset
|
||||
} else {
|
||||
offSet = 16 // dst offset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: 4,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
280
client/firewall/nftables/router_linux_test.go
Normal file
280
client/firewall/nftables/router_linux_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found = 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.ResetForwardRules()
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
})
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(inForwardRuleKey),
|
||||
})
|
||||
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(inNatRuleKey),
|
||||
})
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
manager.ResetForwardRules()
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() int {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return UNKNOWN
|
||||
}
|
||||
if isIptablesClientAvailable(ip) {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func createWorkTable() (*nftables.Table, error) {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
err = sConn.Flush()
|
||||
|
||||
return table, err
|
||||
}
|
||||
|
||||
func deleteWorkTable() {
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
@@ -8,9 +10,8 @@ import (
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
|
||||
ruleID string
|
||||
ip []byte
|
||||
ruleID string
|
||||
ip net.IP
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||
type nftRuleset struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
issuedRules map[string]*Rule
|
||||
rulesetID string
|
||||
}
|
||||
|
||||
type rulesetManager struct {
|
||||
rulesets map[string]*nftRuleset
|
||||
|
||||
nftSetName2rulesetID map[string]string
|
||||
issuedRuleID2rulesetID map[string]string
|
||||
}
|
||||
|
||||
func newRuleManager() *rulesetManager {
|
||||
return &rulesetManager{
|
||||
rulesets: map[string]*nftRuleset{},
|
||||
|
||||
nftSetName2rulesetID: map[string]string{},
|
||||
issuedRuleID2rulesetID: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||
ruleset, ok := r.rulesets[rulesetID]
|
||||
return ruleset, ok
|
||||
}
|
||||
|
||||
func (r *rulesetManager) createRuleset(
|
||||
rulesetID string,
|
||||
nftRule *nftables.Rule,
|
||||
nftSet *nftables.Set,
|
||||
) *nftRuleset {
|
||||
ruleset := nftRuleset{
|
||||
rulesetID: rulesetID,
|
||||
nftRule: nftRule,
|
||||
nftSet: nftSet,
|
||||
issuedRules: map[string]*Rule{},
|
||||
}
|
||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||
if nftSet != nil {
|
||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||
}
|
||||
return &ruleset
|
||||
}
|
||||
|
||||
func (r *rulesetManager) addRule(
|
||||
ruleset *nftRuleset,
|
||||
ip []byte,
|
||||
) (*Rule, error) {
|
||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||
return nil, fmt.Errorf("ruleset not found")
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
nftRule: ruleset.nftRule,
|
||||
nftSet: ruleset.nftSet,
|
||||
ruleID: xid.New().String(),
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
ruleset.issuedRules[rule.ruleID] = &rule
|
||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// deleteRule from ruleset and returns true if contains other rules
|
||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ruleset := r.rulesets[rulesetID]
|
||||
if ruleset.nftRule == nil {
|
||||
return false
|
||||
}
|
||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||
delete(ruleset.issuedRules, rule.ruleID)
|
||||
|
||||
if len(ruleset.issuedRules) == 0 {
|
||||
delete(r.rulesets, ruleset.rulesetID)
|
||||
if rule.nftSet != nil {
|
||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||
//
|
||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||
// we set correct handle value to it.
|
||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||
ruleset, ok := r.rulesets[string(split[0])]
|
||||
if !ok {
|
||||
return fmt.Errorf("ruleset not found")
|
||||
}
|
||||
*ruleset.nftRule = *nftRule
|
||||
return nil
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{
|
||||
UserData: []byte(rulesetID),
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_addRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||
|
||||
ruleset2 := &nftRuleset{
|
||||
rulesetID: "ruleset-2",
|
||||
}
|
||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||
require.Error(t, err, "addRule() should have failed")
|
||||
}
|
||||
|
||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
ip2 := []byte("192.168.1.1")
|
||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule2, "rule should not be nil")
|
||||
|
||||
hasNext := rulesetManager.deleteRule(rule)
|
||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||
|
||||
// Check that the rule is no longer in the manager.
|
||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||
|
||||
hasNext = rulesetManager.deleteRule(rule2)
|
||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||
}
|
||||
|
||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.0.1")
|
||||
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
nftRuleCopy := nftRule
|
||||
nftRuleCopy.Handle = 2
|
||||
nftRuleCopy.UserData = []byte(rulesetID)
|
||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||
// check correct work with references
|
||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
nftSet := nftables.Set{
|
||||
ID: 2,
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
|
||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||
require.True(t, ok, "getRuleset() failed")
|
||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||
|
||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||
require.False(t, ok, "getRuleset() failed")
|
||||
}
|
||||
47
client/firewall/test/cases_linux.go
Normal file
47
client/firewall/test/cases_linux.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build !android
|
||||
|
||||
package test
|
||||
|
||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
var (
|
||||
InsertRuleTestCases = []struct {
|
||||
Name string
|
||||
InputPair firewall.RouterPair
|
||||
}{
|
||||
{
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RemoveRuleTestCases = []struct {
|
||||
Name string
|
||||
InputPair firewall.RouterPair
|
||||
IpVersion string
|
||||
}{
|
||||
{
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
25
client/firewall/uspfilter/allow_netbird.go
Normal file
25
client/firewall/uspfilter/allow_netbird.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !windows
|
||||
|
||||
package uspfilter
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
94
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
94
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type action string
|
||||
|
||||
const (
|
||||
addRule action = "add"
|
||||
deleteRule action = "delete"
|
||||
firewallRuleName = "Netbird"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if isFirewallRuleActive(firewallRuleName) {
|
||||
return nil
|
||||
}
|
||||
return manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
)
|
||||
}
|
||||
|
||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||
|
||||
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||
if action == addRule {
|
||||
args = append(args, extraArgs...)
|
||||
}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func isWindowsFirewallReachable() bool {
|
||||
args := []string{"advfirewall", "show", "allprofiles", "state"}
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Infof("Windows firewall is not reachable, skipping default rule management. Using only user space rules. Error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isFirewallRuleActive(ruleName string) bool {
|
||||
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
_, err := cmd.Output()
|
||||
return err == nil
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
@@ -15,7 +15,7 @@ type Rule struct {
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction fw.RuleDirection
|
||||
direction firewall.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
drop bool
|
||||
|
||||
@@ -10,15 +10,20 @@ import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
@@ -26,10 +31,12 @@ type RuleSet map[string]Rule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
@@ -49,6 +56,20 @@ type decoder struct {
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr.nativeFirewall = nativeFirewall
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
@@ -65,6 +86,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@@ -73,27 +95,50 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == fw.ActionDrop,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
@@ -114,21 +159,21 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
case firewall.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case fw.ProtocolUDP:
|
||||
case firewall.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case fw.ProtocolICMP:
|
||||
case firewall.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case fw.ProtocolALL:
|
||||
case firewall.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if direction == fw.RuleDirectionIN {
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
@@ -140,12 +185,11 @@ func (m *Manager) AddFiltering(
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &r, nil
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -154,7 +198,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if r.direction == fw.RuleDirectionIN {
|
||||
if r.direction == firewall.RuleDirectionIN {
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
@@ -171,17 +215,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
@@ -195,7 +228,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules, true)
|
||||
}
|
||||
|
||||
// dropFilter imlements same logic for booth direction of the traffic
|
||||
// dropFilter implements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@@ -329,7 +362,7 @@ func (m *Manager) AddUDPPacketHook(
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: dPort,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: fw.RuleDirectionOUT,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
@@ -340,7 +373,7 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
r.direction = fw.RuleDirectionIN
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
@@ -362,14 +395,16 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
rule := r
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
rule := r
|
||||
return m.DeleteRule(&rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,13 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(iface.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc == nil {
|
||||
return iface.WGAddress{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
@@ -117,24 +125,32 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
for _, r := range rule {
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
for _, r := range rule2 {
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,10 +166,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
EnVar::SetHKLM
|
||||
EnVar::AddValueEx "path" "$INSTDIR"
|
||||
|
||||
SetShellVarContext current
|
||||
SetShellVarContext all
|
||||
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
|
||||
SetShellVarContext all
|
||||
SectionEnd
|
||||
|
||||
Section -Post
|
||||
@@ -194,12 +193,12 @@ Sleep 3000
|
||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||
Delete "$INSTDIR\wintun.dll"
|
||||
Delete "$INSTDIR\opengl32.dll"
|
||||
RmDir /r "$INSTDIR"
|
||||
|
||||
SetShellVarContext current
|
||||
SetShellVarContext all
|
||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||
SetShellVarContext all
|
||||
|
||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||
@@ -209,8 +208,7 @@ SectionEnd
|
||||
|
||||
|
||||
Function LaunchLink
|
||||
SetShellVarContext current
|
||||
SetShellVarContext all
|
||||
SetOutPath $INSTDIR
|
||||
ShellExecAsUser::ShellExecAsUser "" "$DESKTOP\${APP_NAME}.lnk"
|
||||
SetShellVarContext all
|
||||
FunctionEnd
|
||||
|
||||
@@ -11,49 +11,34 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
Stop()
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
manager firewall.Manager
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type ipsetInfo struct {
|
||||
name string
|
||||
ipCount int
|
||||
}
|
||||
|
||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
firewall: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
//
|
||||
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
|
||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
@@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
if d.manager == nil {
|
||||
if d.firewall == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.manager.Flush(); err != nil {
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
@@ -125,58 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
)
|
||||
}
|
||||
|
||||
applyFailed := false
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||
|
||||
// calculate which IP's can be grouped in by which ipset
|
||||
// to do that we use rule selector (which is just rule properties without IP's)
|
||||
for _, r := range rules {
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipset, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
ipset = &ipsetInfo{}
|
||||
}
|
||||
|
||||
ipset.ipCount++
|
||||
ipsetByRuleSelectors[selector] = ipset
|
||||
}
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||
ipsetName := ""
|
||||
if ipset.name == "" {
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
d.ipsetCounter++
|
||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetByRuleSelectors[selector] = ipsetName
|
||||
}
|
||||
ipsetName = ipset.name
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
applyFailed = true
|
||||
d.rollBack(newRulePairs)
|
||||
break
|
||||
}
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
if applyFailed {
|
||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.manager.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(rules) > 0 {
|
||||
d.rulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for pairID, rules := range d.rulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.manager.DeleteRule(rule); err != nil {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -187,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
d.rulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
// Stop ACL controller and clear firewall state
|
||||
func (d *DefaultManager) Stop() {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
if err := d.manager.Reset(); err != nil {
|
||||
log.WithError(err).Error("reset firewall state")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
@@ -206,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
|
||||
protocol := convertToFirewallProtocol(r.Protocol)
|
||||
if protocol == firewall.ProtocolUnknown {
|
||||
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
action := convertFirewallAction(r.Action)
|
||||
if action == firewall.ActionUnknown {
|
||||
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||
action, err := convertFirewallAction(r.Action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
@@ -233,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
var err error
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
@@ -247,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
d.rulesPairs[ruleID] = rules
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
@@ -260,24 +210,24 @@ func (d *DefaultManager) addInRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule), nil
|
||||
return append(rules, rule...), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
@@ -289,24 +239,24 @@ func (d *DefaultManager) addOutRules(
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(
|
||||
rule, err := d.firewall.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
rules = append(rules, rule...)
|
||||
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(
|
||||
rule, err = d.firewall.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
|
||||
return append(rules, rule), nil
|
||||
return append(rules, rule...), nil
|
||||
}
|
||||
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
@@ -367,7 +317,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
}
|
||||
|
||||
// special case, when we recieve this all network IP address
|
||||
// special case, when we receive this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
@@ -394,7 +344,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
}
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for ther first element ALL, it must be done first
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
||||
mgmProto.FirewallRule_ALL,
|
||||
mgmProto.FirewallRule_ICMP,
|
||||
@@ -462,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
return firewall.ProtocolTCP
|
||||
return firewall.ProtocolTCP, nil
|
||||
case mgmProto.FirewallRule_UDP:
|
||||
return firewall.ProtocolUDP
|
||||
return firewall.ProtocolUDP, nil
|
||||
case mgmProto.FirewallRule_ICMP:
|
||||
return firewall.ProtocolICMP
|
||||
return firewall.ProtocolICMP, nil
|
||||
case mgmProto.FirewallRule_ALL:
|
||||
return firewall.ProtocolALL
|
||||
return firewall.ProtocolALL, nil
|
||||
default:
|
||||
return firewall.ProtocolUnknown
|
||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -481,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
|
||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||
}
|
||||
|
||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
|
||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
||||
switch action {
|
||||
case mgmProto.FirewallRule_ACCEPT:
|
||||
return firewall.ActionAccept
|
||||
return firewall.ActionAccept, nil
|
||||
case mgmProto.FirewallRule_DROP:
|
||||
return firewall.ActionDrop
|
||||
return firewall.ActionDrop, nil
|
||||
default:
|
||||
return firewall.ActionUnknown
|
||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package acl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
// Create creates a firewall manager instance
|
||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
// Create creates a firewall manager instance for the Linux
|
||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
var fm firewall.Manager
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
if fm, err = uspfilter.Create(iface); err != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if fm, err = nftables.Create(iface); err != nil {
|
||||
log.Debugf("failed to create nftables manager: %s", err)
|
||||
// fallback to iptables
|
||||
if fm, err = iptables.Create(iface); err != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -32,18 +37,30 @@ func TestDefaultManager(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFilter(gomock.Any())
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer acl.Stop()
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap)
|
||||
@@ -178,31 +195,33 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
}
|
||||
|
||||
r := rules[0]
|
||||
if r.PeerIP != "0.0.0.0" {
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
} else if r.Direction != mgmProto.FirewallRule_IN {
|
||||
case r.Direction != mgmProto.FirewallRule_IN:
|
||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||
return
|
||||
} else if r.Protocol != mgmProto.FirewallRule_ALL {
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
|
||||
r = rules[1]
|
||||
if r.PeerIP != "0.0.0.0" {
|
||||
switch {
|
||||
case r.PeerIP != "0.0.0.0":
|
||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||
return
|
||||
} else if r.Direction != mgmProto.FirewallRule_OUT {
|
||||
case r.Direction != mgmProto.FirewallRule_OUT:
|
||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||
return
|
||||
} else if r.Protocol != mgmProto.FirewallRule_ALL {
|
||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||
return
|
||||
} else if r.Action != mgmProto.FirewallRule_ACCEPT {
|
||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||
return
|
||||
}
|
||||
@@ -270,7 +289,7 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
|
||||
manager := &DefaultManager{}
|
||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
||||
t.Errorf("we should got same amount of rules as intput, got %v", len(rules))
|
||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -311,18 +330,30 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFilter(gomock.Any())
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer acl.Stop()
|
||||
defer func(fw manager.Manager) {
|
||||
_ = fw.Reset()
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
@@ -174,7 +175,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
|
||||
if tokenResponse.Error == "authorization_pending" {
|
||||
continue
|
||||
} else if tokenResponse.Error == "slow_down" {
|
||||
interval = interval + (3 * time.Second)
|
||||
interval += (3 * time.Second)
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -25,7 +26,7 @@ type HTTPClient interface {
|
||||
}
|
||||
|
||||
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
|
||||
type AuthFlowInfo struct {
|
||||
type AuthFlowInfo struct { //nolint:revive
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
@@ -57,34 +58,52 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
log.Debug("getting device authorization flow info")
|
||||
|
||||
// Try to initialize the Device Authorization Flow
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err == nil {
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||
//
|
||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
log.Debugf("getting device authorization flow info failed with error: %v", err)
|
||||
log.Debugf("falling back to pkce authorization flow info")
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// If Device Authorization Flow failed, try the PKCE Authorization Flow
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
case ok && s.Code() == codes.NotFound:
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
"Please proceed with setting up this device using setup keys " +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
case ok && s.Code() == codes.Unimplemented:
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
default:
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
@@ -78,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a authorization code login flow information.
|
||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
state, err := randomBytesInHex(24)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||
@@ -112,60 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go p.startServer(tokenChan, errChan)
|
||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
log.Errorf("failed to close the server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.handleOAuthToken(token)
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
return TokenInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||
return
|
||||
}
|
||||
port := parsedURL.Port()
|
||||
|
||||
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
|
||||
defer func() {
|
||||
if err := server.Shutdown(context.Background()); err != nil {
|
||||
log.Errorf("error while shutting down pkce flow server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
req.Context(),
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
token, err := tokenValidatorFunc()
|
||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
token, err := p.handleRequest(req)
|
||||
if err != nil {
|
||||
renderPKCEFlowTmpl(w, err)
|
||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||
@@ -176,12 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
|
||||
tokenChan <- token
|
||||
})
|
||||
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
server.Handler = mux
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errChan <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on the state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
req.Context(),
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
@@ -193,7 +197,13 @@ func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo
|
||||
tokenInfo.IDToken = idToken
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil {
|
||||
// if a provider doesn't support an audience, use the Client ID for token verification
|
||||
audience := p.providerConfig.Audience
|
||||
if audience == "" {
|
||||
audience = p.providerConfig.ClientID
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -44,15 +43,14 @@ func isValidAccessToken(token string, audience string) error {
|
||||
}
|
||||
|
||||
// Audience claim of JWT can be a string or an array of strings
|
||||
typ := reflect.TypeOf(claims.Audience)
|
||||
switch typ.Kind() {
|
||||
case reflect.String:
|
||||
if claims.Audience == audience {
|
||||
switch aud := claims.Audience.(type) {
|
||||
case string:
|
||||
if aud == audience {
|
||||
return nil
|
||||
}
|
||||
case reflect.Slice:
|
||||
for _, aud := range claims.Audience.([]interface{}) {
|
||||
if audience == aud {
|
||||
case []interface{}:
|
||||
for _, audItem := range aud {
|
||||
if audStr, ok := audItem.(string); ok && audStr == audience {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -12,16 +13,19 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// managementLegacyPortString is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
|
||||
ManagementLegacyPort = 33073
|
||||
managementLegacyPortString = "33073"
|
||||
// DefaultManagementURL points to the NetBird's cloud management endpoint
|
||||
DefaultManagementURL = "https://api.wiretrustee.com:443"
|
||||
DefaultManagementURL = "https://api.netbird.io:443"
|
||||
// oldDefaultManagementURL points to the NetBird's old cloud management endpoint
|
||||
oldDefaultManagementURL = "https://api.wiretrustee.com:443"
|
||||
// DefaultAdminURL points to NetBird's cloud management console
|
||||
DefaultAdminURL = "https://app.netbird.io:443"
|
||||
)
|
||||
@@ -31,12 +35,18 @@ var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun",
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
PreSharedKey *string
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
DisableAutoConnect *bool
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -50,10 +60,13 @@ type Config struct {
|
||||
WgPort int
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
// ExternalIP mappings, if different than the host interface IP
|
||||
// ExternalIP mappings, if different from the host interface IP
|
||||
//
|
||||
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
|
||||
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
|
||||
@@ -71,6 +84,10 @@ type Config struct {
|
||||
NATExternalIPs []string
|
||||
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
|
||||
CustomDNSAddress string
|
||||
|
||||
// DisableAutoConnect determines whether the client should not start with the service
|
||||
// it's set to false by default due to backwards compatibility
|
||||
DisableAutoConnect bool
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@@ -80,6 +97,7 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
if _, err := util.ReadJson(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -136,15 +154,16 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
SSHKey: string(pem),
|
||||
PrivateKey: wgKey,
|
||||
WgIface: iface.WgInterfaceDefault,
|
||||
WgPort: iface.DefaultWgPort,
|
||||
IFaceBlackList: []string{},
|
||||
DisableIPv6Discovery: false,
|
||||
NATExternalIPs: input.NATExternalIPs,
|
||||
CustomDNSAddress: string(input.CustomDNSAddress),
|
||||
ServerSSHAllowed: util.False(),
|
||||
DisableAutoConnect: false,
|
||||
}
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
@@ -161,10 +180,32 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config.ManagementURL = URL
|
||||
}
|
||||
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
if input.WireguardPort != nil {
|
||||
config.WgPort = *input.WireguardPort
|
||||
}
|
||||
|
||||
config.WgIface = iface.WgInterfaceDefault
|
||||
if input.InterfaceName != nil {
|
||||
config.WgIface = *input.InterfaceName
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil {
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil {
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil {
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
}
|
||||
|
||||
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -215,12 +256,9 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
if *input.PreSharedKey != "" {
|
||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
log.Infof("new pre-shared key provided, replacing old key")
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
@@ -236,6 +274,17 @@ func update(input ConfigInput) (*Config, error) {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.WireguardPort != nil {
|
||||
config.WgPort = *input.WireguardPort
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil {
|
||||
config.WgIface = *input.InterfaceName
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
|
||||
config.NATExternalIPs = input.NATExternalIPs
|
||||
refresh = true
|
||||
@@ -246,6 +295,31 @@ func update(input ConfigInput) (*Config, error) {
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.RosenpassEnabled != nil {
|
||||
config.RosenpassEnabled = *input.RosenpassEnabled
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.RosenpassPermissive != nil {
|
||||
config.RosenpassPermissive = *input.RosenpassPermissive
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.DisableAutoConnect != nil {
|
||||
config.DisableAutoConnect = *input.DisableAutoConnect
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if input.ServerSSHAllowed != nil {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if config.ServerSSHAllowed == nil {
|
||||
config.ServerSSHAllowed = util.True()
|
||||
refresh = true
|
||||
}
|
||||
|
||||
if refresh {
|
||||
// since we have new management URL, we need to update config file
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
@@ -273,9 +347,9 @@ func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
if parsedMgmtURL.Port() == "" {
|
||||
switch parsedMgmtURL.Scheme {
|
||||
case "https":
|
||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":443"
|
||||
parsedMgmtURL.Host += ":443"
|
||||
case "http":
|
||||
parsedMgmtURL.Host = parsedMgmtURL.Host + ":80"
|
||||
parsedMgmtURL.Host += ":80"
|
||||
default:
|
||||
log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL)
|
||||
}
|
||||
@@ -305,3 +379,86 @@ func configFileIsExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
|
||||
config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if config.ManagementURL.Port() != managementLegacyPortString &&
|
||||
config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
@@ -23,9 +26,6 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
||||
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
managementURL := "https://test.management.url:33071"
|
||||
adminURL := "https://app.admin.url:443"
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
@@ -63,22 +63,7 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 4: new empty pre-shared key config -> fetch it
|
||||
newPreSharedKey := ""
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &newPreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 5: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
@@ -137,3 +122,60 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
previousManagementURL string
|
||||
expectedManagementURL string
|
||||
fileShouldNotChange bool
|
||||
}{
|
||||
{
|
||||
name: "Update old management URL with legacy port",
|
||||
previousManagementURL: "https://api.wiretrustee.com:33073",
|
||||
expectedManagementURL: DefaultManagementURL,
|
||||
},
|
||||
{
|
||||
name: "Update old management URL",
|
||||
previousManagementURL: oldDefaultManagementURL,
|
||||
expectedManagementURL: DefaultManagementURL,
|
||||
},
|
||||
{
|
||||
name: "No update needed when management URL is up to date",
|
||||
previousManagementURL: DefaultManagementURL,
|
||||
expectedManagementURL: DefaultManagementURL,
|
||||
fileShouldNotChange: true,
|
||||
},
|
||||
{
|
||||
name: "No update needed when not using cloud management",
|
||||
previousManagementURL: "https://netbird.example.com:33073",
|
||||
expectedManagementURL: "https://netbird.example.com:33073",
|
||||
fileShouldNotChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: tt.previousManagementURL,
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err, "failed to create testing config")
|
||||
previousStats, err := os.Stat(configPath)
|
||||
require.NoError(t, err, "failed to create testing config stats")
|
||||
resultConfig, err := UpdateOldManagementURL(context.TODO(), config, configPath)
|
||||
require.NoError(t, err, "got error when updating old management url")
|
||||
require.Equal(t, tt.expectedManagementURL, resultConfig.ManagementURL.String())
|
||||
newStats, err := os.Stat(configPath)
|
||||
require.NoError(t, err, "failed to create testing config stats")
|
||||
switch tt.fileShouldNotChange {
|
||||
case true:
|
||||
require.Equal(t, previousStats.ModTime(), newStats.ModTime(), "file should not change")
|
||||
case false:
|
||||
require.NotEqual(t, previousStats.ModTime(), newStats.ModTime(), "file should have changed")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -13,8 +14,8 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -22,27 +23,84 @@ import (
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// RunClientWithProbes runs the client's main logic with probes attached
|
||||
func RunClientWithProbes(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||
func RunClientMobile(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
tunAdapter iface.TunAdapter,
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []string,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
RouteListener: routeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||
func RunClientiOS(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
mobileDependency MobileDependency,
|
||||
mgmProbe *Probe,
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
) error {
|
||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
||||
|
||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
|
||||
log.Errorf("checking unclean shutdown error: %s", err)
|
||||
}
|
||||
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -91,12 +149,12 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
engineCtx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
statusRecorder.MarkManagementDisconnected()
|
||||
statusRecorder.MarkManagementDisconnected(state.err)
|
||||
statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
|
||||
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
@@ -140,8 +198,10 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
statusRecorder.UpdateSignalAddress(signalURL)
|
||||
|
||||
statusRecorder.MarkSignalDisconnected()
|
||||
defer statusRecorder.MarkSignalDisconnected()
|
||||
statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
statusRecorder.MarkSignalDisconnected(state.err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey)
|
||||
@@ -169,7 +229,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
@@ -179,8 +239,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
|
||||
<-engineCtx.Done()
|
||||
statusRecorder.ClientTeardown()
|
||||
|
||||
@@ -194,13 +252,14 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
log.Info("stopped NetBird client")
|
||||
|
||||
if _, err := state.Status(); err == ErrResetConnection {
|
||||
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
@@ -224,6 +283,9 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -272,83 +334,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
|
||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||
// The check is performed only for the NetBird's managed version.
|
||||
func UpdateOldManagementPort(ctx context.Context, config *Config, configPath string) (*Config, error) {
|
||||
|
||||
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() {
|
||||
// only do the check for the NetBird's managed version
|
||||
return config, nil
|
||||
}
|
||||
|
||||
var mgmTlsEnabled bool
|
||||
if config.ManagementURL.Scheme == "https" {
|
||||
mgmTlsEnabled = true
|
||||
}
|
||||
|
||||
if !mgmTlsEnabled {
|
||||
// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
if mgmTlsEnabled && config.ManagementURL.Port() == fmt.Sprintf("%d", ManagementLegacyPort) {
|
||||
|
||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
||||
config.ManagementURL.Scheme, config.ManagementURL.Hostname(), 443))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// here we check whether we could switch from the legacy 33073 port to the new 443
|
||||
log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
|
||||
config.ManagementURL.String(), newURL.String())
|
||||
key, err := wgtypes.ParseKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
|
||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, err
|
||||
}
|
||||
defer func() {
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close the Management service client %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// everything is alright => update the config
|
||||
newConfig, err := UpdateConfig(ConfigInput{
|
||||
ManagementURL: newURL.String(),
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return config, fmt.Errorf("failed updating config file: %v", err)
|
||||
}
|
||||
log.Infof("successfully switched to the new Management URL: %s", newURL.String())
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
var sri interface{} = statusRecorder
|
||||
mgmNotifier, _ := sri.(mgm.ConnStateNotifier)
|
||||
|
||||
@@ -4,9 +4,11 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"time"
|
||||
)
|
||||
|
||||
const dbusDefaultFlag = 0
|
||||
@@ -14,6 +16,7 @@ const dbusDefaultFlag = 0
|
||||
func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
||||
obj, closeConn, err := getDbusObject(dest, path)
|
||||
if err != nil {
|
||||
log.Tracef("error getting dbus object: %s", err)
|
||||
return false
|
||||
}
|
||||
defer closeConn()
|
||||
@@ -21,14 +24,18 @@ func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store()
|
||||
return err == nil
|
||||
if err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
|
||||
log.Tracef("error calling dbus: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, fmt.Errorf("get dbus: %w", err)
|
||||
}
|
||||
obj := conn.Object(dest, path)
|
||||
|
||||
|
||||
@@ -5,164 +5,324 @@ package dns
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
||||
fileGeneratedResolvConfSearchBeginContent = "search "
|
||||
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
|
||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||
fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
|
||||
"%s\n"
|
||||
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
||||
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
|
||||
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
|
||||
|
||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||
|
||||
fileMaxLineCharsLimit = 256
|
||||
fileMaxNumberOfSearchDomains = 6
|
||||
)
|
||||
|
||||
const (
|
||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||
fileMaxLineCharsLimit = 256
|
||||
fileMaxNumberOfSearchDomains = 6
|
||||
dnsFailoverTimeout = 4 * time.Second
|
||||
dnsFailoverAttempts = 1
|
||||
)
|
||||
|
||||
var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
|
||||
|
||||
type fileConfigurator struct {
|
||||
originalPerms os.FileMode
|
||||
repair *repair
|
||||
|
||||
originalPerms os.FileMode
|
||||
nbNameserverIP string
|
||||
}
|
||||
|
||||
func newFileConfigurator() (hostManager, error) {
|
||||
return &fileConfigurator{}, nil
|
||||
fc := &fileConfigurator{}
|
||||
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
backupFileExist := false
|
||||
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
||||
if err == nil {
|
||||
backupFileExist = true
|
||||
}
|
||||
|
||||
if !config.routeAll {
|
||||
if !config.RouteAll {
|
||||
if backupFileExist {
|
||||
err = f.restore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err)
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
}
|
||||
managerType, err := getOSDNSManagerType()
|
||||
|
||||
if !backupFileExist {
|
||||
err = f.backup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbSearchDomains := searchDomains(config)
|
||||
f.nbNameserverIP = config.ServerIP
|
||||
|
||||
resolvConf, err := parseBackupResolvConf()
|
||||
if err != nil {
|
||||
log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
f.repair.stopWatchFileChanges()
|
||||
|
||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch managerType {
|
||||
case fileManager, netbirdManager:
|
||||
if !backupFileExist {
|
||||
err = f.backup()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to backup the resolv.conf file")
|
||||
}
|
||||
}
|
||||
default:
|
||||
// todo improve this and maybe restart DNS manager from scratch
|
||||
return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent")
|
||||
}
|
||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
|
||||
continue
|
||||
}
|
||||
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
|
||||
// lets log all skipped domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
|
||||
continue
|
||||
}
|
||||
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
||||
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||
|
||||
searchDomains += " " + dConf.domain
|
||||
appendedDomains++
|
||||
}
|
||||
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
nameServers,
|
||||
options)
|
||||
|
||||
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
|
||||
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
||||
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
||||
if err != nil {
|
||||
log.Errorf("Could not read existing resolv.conf")
|
||||
}
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
|
||||
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)
|
||||
if err != nil {
|
||||
err = f.restore()
|
||||
if err != nil {
|
||||
restoreErr := f.restore()
|
||||
if restoreErr != nil {
|
||||
log.Errorf("attempt to restore default file failed with error: %s", err)
|
||||
}
|
||||
return err
|
||||
return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)
|
||||
|
||||
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||
|
||||
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
||||
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restoreHostDNS() error {
|
||||
f.repair.stopWatchFileChanges()
|
||||
return f.restore()
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) backup() error {
|
||||
stats, err := os.Stat(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err)
|
||||
return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
|
||||
f.originalPerms = stats.Mode()
|
||||
|
||||
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err)
|
||||
return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fileConfigurator) restore() error {
|
||||
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
|
||||
log.Debugf("creating managed file %s", fileName)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(content)
|
||||
err := os.WriteFile(fileName, buf.Bytes(), permissions)
|
||||
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||
resolvConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
|
||||
return fmt.Errorf("parse current resolv.conf: %w", err)
|
||||
}
|
||||
|
||||
// no current nameservers set -> restore
|
||||
if len(resolvConf.nameServers) == 0 {
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
||||
// not a valid first nameserver -> restore
|
||||
if err != nil {
|
||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
// current address is still netbird's non-available dns address -> restore
|
||||
// comparing parsed addresses only, to remove ambiguity
|
||||
if currentDNSAddress.String() == storedDNSAddress.String() {
|
||||
return restoreResolvConfFile()
|
||||
}
|
||||
|
||||
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreResolvConfFile() error {
|
||||
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
||||
|
||||
if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil {
|
||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
|
||||
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
|
||||
ns := make([]string, 1, len(cfg.nameServers)+1)
|
||||
ns[0] = nbNameserverIP
|
||||
for _, cfgNs := range cfg.nameServers {
|
||||
if nbNameserverIP != cfgNs {
|
||||
ns = append(ns, cfgNs)
|
||||
}
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
||||
|
||||
for _, cfgLine := range others {
|
||||
buf.WriteString(cfgLine)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(searchDomains) > 0 {
|
||||
buf.WriteString("search ")
|
||||
buf.WriteString(strings.Join(searchDomains, " "))
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, ns := range nameServers {
|
||||
buf.WriteString("nameserver ")
|
||||
buf.WriteString(ns)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func searchDomains(config HostDNSConfig) []string {
|
||||
listOfDomains := make([]string, 0)
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.MatchOnly || dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
||||
// merge search Domains lists and cut off the list if it is too long
|
||||
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
||||
lineSize := len("search")
|
||||
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
|
||||
|
||||
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
|
||||
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
|
||||
|
||||
return searchDomainsList
|
||||
}
|
||||
|
||||
// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
|
||||
// extend s slice with vs elements
|
||||
// return with the number of characters in the searchDomains line
|
||||
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
||||
for _, sd := range vs {
|
||||
duplicated := false
|
||||
for _, fs := range *s {
|
||||
if fs == sd {
|
||||
duplicated = true
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if duplicated {
|
||||
continue
|
||||
}
|
||||
|
||||
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
||||
if tmpCharsNumber > fileMaxLineCharsLimit {
|
||||
// lets log all skipped Domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
|
||||
continue
|
||||
}
|
||||
|
||||
initialLineChars = tmpCharsNumber
|
||||
|
||||
if len(*s) >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped Domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
|
||||
continue
|
||||
}
|
||||
*s = append(*s, sd)
|
||||
}
|
||||
|
||||
return initialLineChars
|
||||
}
|
||||
|
||||
func copyFile(src, dest string) error {
|
||||
stats, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err)
|
||||
return fmt.Errorf("checking stats for %s file when copying it. Error: %s", src, err)
|
||||
}
|
||||
|
||||
bytesRead, err := os.ReadFile(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err)
|
||||
return fmt.Errorf("reading the file %s file for copy. Error: %s", src, err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(dest, bytesRead, stats.Mode())
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err)
|
||||
return fmt.Errorf("writing the destination file %s for copy. Error: %s", dest, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isContains(subList []string, list []string) bool {
|
||||
for _, sl := range subList {
|
||||
var found bool
|
||||
for _, l := range list {
|
||||
if sl == l {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
125
client/internal/dns/file_linux_test.go
Normal file
125
client/internal/dns/file_linux_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_mergeSearchDomains(t *testing.T) {
|
||||
searchDomains := []string{"a", "b"}
|
||||
originDomains := []string{"c", "d"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 4 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mergeSearchTooMuchDomains(t *testing.T) {
|
||||
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
|
||||
originDomains := []string{"h", "i"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 6 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
|
||||
searchDomains := []string{"a", "b"}
|
||||
originDomains := []string{"c", "d", "e", "f", "g"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 6 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mergeSearchTooLongDomain(t *testing.T) {
|
||||
searchDomains := []string{getLongLine()}
|
||||
originDomains := []string{"b"}
|
||||
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 1 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
|
||||
}
|
||||
|
||||
searchDomains = []string{"b"}
|
||||
originDomains = []string{getLongLine()}
|
||||
|
||||
mergedDomains = mergeSearchDomains(searchDomains, originDomains)
|
||||
if len(mergedDomains) != 1 {
|
||||
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isContains(t *testing.T) {
|
||||
type args struct {
|
||||
subList []string
|
||||
list []string
|
||||
}
|
||||
tests := []struct {
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"a", "b", "c"},
|
||||
list: []string{"a", "b", "c"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"a"},
|
||||
list: []string{"a", "b", "c"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"d"},
|
||||
list: []string{"a", "b", "c"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{"a"},
|
||||
list: []string{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{},
|
||||
list: []string{"b"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
subList: []string{},
|
||||
list: []string{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("list check test", func(t *testing.T) {
|
||||
if got := isContains(tt.args.subList, tt.args.list); got != tt.want {
|
||||
t.Errorf("isContains() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getLongLine() string {
|
||||
x := "search "
|
||||
for {
|
||||
for i := 0; i <= 9; i++ {
|
||||
if len(x) > fileMaxLineCharsLimit {
|
||||
return x
|
||||
}
|
||||
x = fmt.Sprintf("%s%d", x, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
168
client/internal/dns/file_parser_linux.go
Normal file
168
client/internal/dns/file_parser_linux.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
)
|
||||
|
||||
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
|
||||
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
|
||||
|
||||
type resolvConf struct {
|
||||
nameServers []string
|
||||
searchDomains []string
|
||||
others []string
|
||||
}
|
||||
|
||||
func (r *resolvConf) String() string {
|
||||
return fmt.Sprintf("search domains: %v, name servers: %v, others: %s", r.searchDomains, r.nameServers, r.others)
|
||||
}
|
||||
|
||||
func parseDefaultResolvConf() (*resolvConf, error) {
|
||||
return parseResolvConfFile(defaultResolvConfPath)
|
||||
}
|
||||
|
||||
func parseBackupResolvConf() (*resolvConf, error) {
|
||||
return parseResolvConfFile(fileDefaultResolvConfBackupLocation)
|
||||
}
|
||||
|
||||
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||
rconf := &resolvConf{
|
||||
searchDomains: make([]string, 0),
|
||||
nameServers: make([]string, 0),
|
||||
others: make([]string, 0),
|
||||
}
|
||||
|
||||
file, err := os.Open(resolvConfFile)
|
||||
if err != nil {
|
||||
return rconf, fmt.Errorf("failed to open %s file: %w", resolvConfFile, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("failed closing %s: %s", resolvConfFile, err)
|
||||
}
|
||||
}()
|
||||
|
||||
cur, err := os.ReadFile(resolvConfFile)
|
||||
if err != nil {
|
||||
return rconf, fmt.Errorf("failed to read %s file: %w", resolvConfFile, err)
|
||||
}
|
||||
|
||||
if len(cur) == 0 {
|
||||
return rconf, fmt.Errorf("file is empty")
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(string(cur), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "domain") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
|
||||
line = strings.ReplaceAll(line, "rotate", "")
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) == 1 {
|
||||
continue
|
||||
}
|
||||
line = strings.Join(splitLines, " ")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "search") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
rconf.searchDomains = splitLines[1:]
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "nameserver") {
|
||||
splitLines := strings.Fields(line)
|
||||
if len(splitLines) != 2 {
|
||||
continue
|
||||
}
|
||||
rconf.nameServers = append(rconf.nameServers, splitLines[1])
|
||||
continue
|
||||
}
|
||||
|
||||
if line != "" {
|
||||
rconf.others = append(rconf.others, line)
|
||||
}
|
||||
}
|
||||
return rconf, nil
|
||||
}
|
||||
|
||||
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
|
||||
// otherwise it adds a new option with timeout and attempts.
|
||||
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
|
||||
configs := make([]string, len(input))
|
||||
copy(configs, input)
|
||||
|
||||
for i, config := range configs {
|
||||
if strings.HasPrefix(config, "options") {
|
||||
config = strings.ReplaceAll(config, "rotate", "")
|
||||
config = strings.Join(strings.Fields(config), " ")
|
||||
|
||||
if strings.Contains(config, "timeout:") {
|
||||
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
|
||||
} else {
|
||||
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
|
||||
}
|
||||
|
||||
if strings.Contains(config, "attempts:") {
|
||||
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
|
||||
} else {
|
||||
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
|
||||
}
|
||||
|
||||
configs[i] = config
|
||||
return configs
|
||||
}
|
||||
}
|
||||
|
||||
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
|
||||
}
|
||||
|
||||
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
||||
// and writes the file back to the original location
|
||||
func removeFirstNbNameserver(filename, nameserverIP string) error {
|
||||
resolvConf, err := parseResolvConfFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
||||
}
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", filename, err)
|
||||
}
|
||||
|
||||
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
|
||||
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
||||
|
||||
stat, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", filename, err)
|
||||
}
|
||||
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
||||
return fmt.Errorf("write %s: %w", filename, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
304
client/internal/dns/file_parser_linux_test.go
Normal file
304
client/internal/dns/file_parser_linux_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_parseResolvConf(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedSearch []string
|
||||
expectedNS []string
|
||||
expectedOther []string
|
||||
}{
|
||||
{
|
||||
input: `domain example.org
|
||||
search example.org
|
||||
nameserver 192.168.0.1
|
||||
`,
|
||||
expectedSearch: []string{"example.org"},
|
||||
expectedNS: []string{"192.168.0.1"},
|
||||
expectedOther: []string{},
|
||||
},
|
||||
{
|
||||
input: `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||
# Do not edit.
|
||||
#
|
||||
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||
#
|
||||
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||
# all known uplink DNS servers. This file lists all configured search domains.
|
||||
#
|
||||
# Third party programs should typically not access this file directly, but only
|
||||
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||
# different way, replace this symlink by a static file or a different symlink.
|
||||
#
|
||||
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||
# operation for /etc/resolv.conf.
|
||||
|
||||
nameserver 192.168.2.1
|
||||
nameserver 100.81.99.197
|
||||
search netbird.cloud
|
||||
`,
|
||||
expectedSearch: []string{"netbird.cloud"},
|
||||
expectedNS: []string{"192.168.2.1", "100.81.99.197"},
|
||||
expectedOther: []string{},
|
||||
},
|
||||
{
|
||||
input: `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||
# Do not edit.
|
||||
#
|
||||
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||
#
|
||||
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||
# all known uplink DNS servers. This file lists all configured search domains.
|
||||
#
|
||||
# Third party programs should typically not access this file directly, but only
|
||||
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||
# different way, replace this symlink by a static file or a different symlink.
|
||||
#
|
||||
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||
# operation for /etc/resolv.conf.
|
||||
|
||||
nameserver 192.168.2.1
|
||||
nameserver 100.81.99.197
|
||||
search netbird.cloud
|
||||
options debug
|
||||
`,
|
||||
expectedSearch: []string{"netbird.cloud"},
|
||||
expectedNS: []string{"192.168.2.1", "100.81.99.197"},
|
||||
expectedOther: []string{"options debug"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run("test", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
err := os.WriteFile(tmpResolvConf, []byte(testCase.input), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg, err := parseResolvConfFile(tmpResolvConf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ok := compareLists(cfg.searchDomains, testCase.expectedSearch)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
||||
}
|
||||
|
||||
ok = compareLists(cfg.nameServers, testCase.expectedNS)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
|
||||
}
|
||||
|
||||
ok = compareLists(cfg.others, testCase.expectedOther)
|
||||
if !ok {
|
||||
t.Errorf("invalid parse result for others, expected: %v, got: %v", testCase.expectedOther, cfg.others)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func compareLists(search []string, search2 []string) bool {
|
||||
if len(search) != len(search2) {
|
||||
return false
|
||||
}
|
||||
for i, v := range search {
|
||||
if v != search2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Test_emptyFile(t *testing.T) {
|
||||
cfg, err := parseResolvConfFile("/tmp/nothing")
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
if len(cfg.others) != 0 || len(cfg.searchDomains) != 0 || len(cfg.nameServers) != 0 {
|
||||
t.Errorf("expected empty config, got %v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_symlink(t *testing.T) {
|
||||
input := `# This is /run/systemd/resolve/resolv.conf managed by man:systemd-resolved(8).
|
||||
# Do not edit.
|
||||
#
|
||||
# This file might be symlinked as /etc/resolv.conf. If you're looking at
|
||||
# /etc/resolv.conf and seeing this text, you have followed the symlink.
|
||||
#
|
||||
# This is a dynamic resolv.conf file for connecting local clients directly to
|
||||
# all known uplink DNS servers. This file lists all configured search domains.
|
||||
#
|
||||
# Third party programs should typically not access this file directly, but only
|
||||
# through the symlink at /etc/resolv.conf. To manage man:resolv.conf(5) in a
|
||||
# different way, replace this symlink by a static file or a different symlink.
|
||||
#
|
||||
# See man:systemd-resolved.service(8) for details about the supported modes of
|
||||
# operation for /etc/resolv.conf.
|
||||
|
||||
nameserver 192.168.0.1
|
||||
`
|
||||
|
||||
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
err := os.WriteFile(tmpResolvConf, []byte(input), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpLink := filepath.Join(t.TempDir(), "symlink")
|
||||
err = os.Symlink(tmpResolvConf, tmpLink)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := parseResolvConfFile(tmpLink)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(cfg.nameServers) != 1 {
|
||||
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareOptionsWithTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
others []string
|
||||
timeout int
|
||||
attempts int
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Append new options with timeout and attempts",
|
||||
others: []string{"some config"},
|
||||
timeout: 2,
|
||||
attempts: 2,
|
||||
expected: []string{"some config", "options timeout:2 attempts:2"},
|
||||
},
|
||||
{
|
||||
name: "Modify existing options to exclude rotate and include timeout and attempts",
|
||||
others: []string{"some config", "options rotate someother"},
|
||||
timeout: 3,
|
||||
attempts: 2,
|
||||
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
|
||||
},
|
||||
{
|
||||
name: "Existing options with timeout and attempts are updated",
|
||||
others: []string{"some config", "options timeout:4 attempts:3"},
|
||||
timeout: 5,
|
||||
attempts: 4,
|
||||
expected: []string{"some config", "options timeout:5 attempts:4"},
|
||||
},
|
||||
{
|
||||
name: "Modify existing options, add missing attempts before timeout",
|
||||
others: []string{"some config", "options timeout:4"},
|
||||
timeout: 4,
|
||||
attempts: 3,
|
||||
expected: []string{"some config", "options attempts:3 timeout:4"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirstNbNameserver(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
content string
|
||||
ipToRemove string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Unrelated nameservers with comments and options",
|
||||
content: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `# This is a comment
|
||||
options rotate
|
||||
nameserver 1.1.1.1
|
||||
# Another comment
|
||||
nameserver 8.8.4.4
|
||||
search example.com`,
|
||||
},
|
||||
{
|
||||
name: "First nameserver matches",
|
||||
content: `search example.com
|
||||
nameserver 9.9.9.9
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `search example.com
|
||||
# oof, a comment
|
||||
nameserver 8.8.4.4
|
||||
options attempts:5`,
|
||||
},
|
||||
{
|
||||
name: "Target IP not the first nameserver",
|
||||
// nolint:dupword
|
||||
content: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
// nolint:dupword
|
||||
expected: `# Comment about the first nameserver
|
||||
nameserver 8.8.4.4
|
||||
# Comment before our target
|
||||
nameserver 9.9.9.9
|
||||
options timeout:2`,
|
||||
},
|
||||
{
|
||||
name: "Only nameserver matches",
|
||||
content: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
ipToRemove: "9.9.9.9",
|
||||
expected: `options debug
|
||||
nameserver 9.9.9.9
|
||||
search localdomain`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFile := filepath.Join(tempDir, "resolv.conf")
|
||||
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tempFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
||||
})
|
||||
}
|
||||
}
|
||||
159
client/internal/dns/file_repair_linux.go
Normal file
159
client/internal/dns/file_repair_linux.go
Normal file
@@ -0,0 +1,159 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
eventTypes = []fsnotify.Op{
|
||||
fsnotify.Create,
|
||||
fsnotify.Write,
|
||||
fsnotify.Remove,
|
||||
fsnotify.Rename,
|
||||
}
|
||||
)
|
||||
|
||||
type repairConfFn func([]string, string, *resolvConf) error
|
||||
|
||||
type repair struct {
|
||||
operationFile string
|
||||
updateFn repairConfFn
|
||||
watchDir string
|
||||
|
||||
inotify *fsnotify.Watcher
|
||||
inotifyWg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newRepair(operationFile string, updateFn repairConfFn) *repair {
|
||||
targetFile := targetFile(operationFile)
|
||||
return &repair{
|
||||
operationFile: targetFile,
|
||||
watchDir: path.Dir(targetFile),
|
||||
updateFn: updateFn,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
|
||||
if f.inotify != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("start to watch resolv.conf: %s", f.operationFile)
|
||||
inotify, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Errorf("failed to start inotify watcher for resolv.conf: %s", err)
|
||||
return
|
||||
}
|
||||
f.inotify = inotify
|
||||
|
||||
f.inotifyWg.Add(1)
|
||||
go func() {
|
||||
defer f.inotifyWg.Done()
|
||||
for event := range f.inotify.Events {
|
||||
if !f.isEventRelevant(event) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("%s changed, check if it is broken", f.operationFile)
|
||||
|
||||
rConf, err := parseResolvConfFile(f.operationFile)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("check resolv.conf parameters: %s", rConf)
|
||||
if !isNbParamsMissing(nbSearchDomains, nbNameserverIP, rConf) {
|
||||
log.Tracef("resolv.conf still correct, skip the update")
|
||||
continue
|
||||
}
|
||||
log.Info("broken params in resolv.conf, repairing it...")
|
||||
|
||||
err = f.inotify.Remove(f.watchDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to repair resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
err = f.inotify.Add(f.watchDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to re-add inotify watch for resolv.conf: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = f.inotify.Add(f.watchDir)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add inotify watch for resolv.conf: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *repair) stopWatchFileChanges() {
|
||||
if f.inotify == nil {
|
||||
return
|
||||
}
|
||||
err := f.inotify.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close resolv.conf inotify: %v", err)
|
||||
}
|
||||
f.inotifyWg.Wait()
|
||||
f.inotify = nil
|
||||
}
|
||||
|
||||
func (f *repair) isEventRelevant(event fsnotify.Event) bool {
|
||||
var ok bool
|
||||
for _, et := range eventTypes {
|
||||
if event.Has(et) {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if event.Name == f.operationFile {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs
|
||||
// check the NetBird related nameserver IP at the first place
|
||||
// check the NetBird related search domains in the search domains list
|
||||
func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *resolvConf) bool {
|
||||
if !isContains(nbSearchDomains, rConf.searchDomains) {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(rConf.nameServers) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if rConf.nameServers[0] != nbNameserverIP {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func targetFile(filename string) string {
|
||||
target, err := filepath.EvalSymlinks(filename)
|
||||
if err != nil {
|
||||
log.Errorf("evarl err: %s", err)
|
||||
}
|
||||
return target
|
||||
}
|
||||
175
client/internal/dns/file_repair_linux_test.go
Normal file
175
client/internal/dns/file_repair_linux_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func Test_newRepairtmp(t *testing.T) {
|
||||
type args struct {
|
||||
resolvConfContent string
|
||||
touchedConfContent string
|
||||
wantChange bool
|
||||
}
|
||||
tests := []args{
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something somethingelse`,
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
searchdomain netbird.cloud something`,
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
searchdomain something`,
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 10.0.0.1`,
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
resolvConfContent: `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`,
|
||||
|
||||
touchedConfContent: `
|
||||
nameserver 8.8.8.8`,
|
||||
wantChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run("test", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
workDir := t.TempDir()
|
||||
operationFile := workDir + "/resolv.conf"
|
||||
err := os.WriteFile(operationFile, []byte(tt.resolvConfContent), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
var changed bool
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
updateFn := func([]string, string, *resolvConf) error {
|
||||
changed = true
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
r := newRepair(operationFile, updateFn)
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||
|
||||
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
r.stopWatchFileChanges()
|
||||
|
||||
if changed != tt.wantChange {
|
||||
t.Errorf("unexpected result: want: %v, got: %v", tt.wantChange, changed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newRepairSymlink(t *testing.T) {
|
||||
resolvConfContent := `
|
||||
nameserver 10.0.0.1
|
||||
nameserver 8.8.8.8
|
||||
searchdomain netbird.cloud something`
|
||||
|
||||
modifyContent := `nameserver 8.8.8.8`
|
||||
|
||||
tmpResolvConf := filepath.Join(t.TempDir(), "resolv.conf")
|
||||
err := os.WriteFile(tmpResolvConf, []byte(resolvConfContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpLink := filepath.Join(t.TempDir(), "symlink")
|
||||
err = os.Symlink(tmpResolvConf, tmpLink)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var changed bool
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
updateFn := func([]string, string, *resolvConf) error {
|
||||
changed = true
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
r := newRepair(tmpLink, updateFn)
|
||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
||||
|
||||
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write out resolv.conf: %s", err)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
r.stopWatchFileChanges()
|
||||
|
||||
if changed != true {
|
||||
t.Errorf("unexpected result: want: %v, got: %v", true, false)
|
||||
}
|
||||
}
|
||||
@@ -2,37 +2,40 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config hostDNSConfig) error
|
||||
applyDNSConfig(config HostDNSConfig) error
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
||||
}
|
||||
|
||||
type hostDNSConfig struct {
|
||||
domains []domainConfig
|
||||
routeAll bool
|
||||
serverIP string
|
||||
serverPort int
|
||||
type HostDNSConfig struct {
|
||||
Domains []DomainConfig `json:"domains"`
|
||||
RouteAll bool `json:"routeAll"`
|
||||
ServerIP string `json:"serverIP"`
|
||||
ServerPort int `json:"serverPort"`
|
||||
}
|
||||
|
||||
type domainConfig struct {
|
||||
disabled bool
|
||||
domain string
|
||||
matchOnly bool
|
||||
type DomainConfig struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
Domain string `json:"domain"`
|
||||
MatchOnly bool `json:"matchOnly"`
|
||||
}
|
||||
|
||||
type mockHostConfigurator struct {
|
||||
applyDNSConfigFunc func(config hostDNSConfig) error
|
||||
restoreHostDNSFunc func() error
|
||||
supportCustomPortFunc func() bool
|
||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
||||
restoreHostDNSFunc func() error
|
||||
supportCustomPortFunc func() bool
|
||||
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
if m.applyDNSConfigFunc != nil {
|
||||
return m.applyDNSConfigFunc(config)
|
||||
}
|
||||
@@ -53,40 +56,48 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
||||
if m.restoreUncleanShutdownDNSFunc != nil {
|
||||
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
|
||||
}
|
||||
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
|
||||
}
|
||||
|
||||
func newNoopHostMocker() hostManager {
|
||||
return &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil },
|
||||
restoreHostDNSFunc: func() error { return nil },
|
||||
supportCustomPortFunc: func() bool { return true },
|
||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
||||
restoreHostDNSFunc: func() error { return nil },
|
||||
supportCustomPortFunc: func() bool { return true },
|
||||
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig {
|
||||
config := hostDNSConfig{
|
||||
routeAll: false,
|
||||
serverIP: ip,
|
||||
serverPort: port,
|
||||
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig {
|
||||
config := HostDNSConfig{
|
||||
RouteAll: false,
|
||||
ServerIP: ip,
|
||||
ServerPort: port,
|
||||
}
|
||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||
if len(nsConfig.NameServers) == 0 {
|
||||
continue
|
||||
}
|
||||
if nsConfig.Primary {
|
||||
config.routeAll = true
|
||||
config.RouteAll = true
|
||||
}
|
||||
|
||||
for _, domain := range nsConfig.Domains {
|
||||
config.domains = append(config.domains, domainConfig{
|
||||
domain: strings.TrimSuffix(domain, "."),
|
||||
matchOnly: true,
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(domain, "."),
|
||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
config.domains = append(config.domains, domainConfig{
|
||||
domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
matchOnly: false,
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
MatchOnly: false,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package dns
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type androidHostManager struct {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
func newHostManager() (hostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
}
|
||||
|
||||
func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,3 +20,7 @@ func (a androidHostManager) restoreHostDNS() error {
|
||||
func (a androidHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -32,7 +36,7 @@ type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
}
|
||||
|
||||
func newHostManager(_ WGIface) (hostManager, error) {
|
||||
func newHostManager() (hostManager, error) {
|
||||
return &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
@@ -42,21 +46,26 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
|
||||
if config.routeAll {
|
||||
err = s.addDNSSetupForAll(config.serverIP, config.serverPort)
|
||||
if config.RouteAll {
|
||||
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns setup for all: %w", err)
|
||||
}
|
||||
} else if s.primaryServiceID != "" {
|
||||
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("remote key from system config: %w", err)
|
||||
}
|
||||
s.primaryServiceID = ""
|
||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -64,37 +73,37 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, dConf.domain)
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
if len(matchDomains) != 0 {
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort)
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing match domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(matchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort)
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing search domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(searchKey)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -117,7 +126,11 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
_, err := runSystemConfigCommand(wrapCommand(lines))
|
||||
if err != nil {
|
||||
log.Errorf("got an error while cleaning the system configuration: %s", err)
|
||||
return err
|
||||
return fmt.Errorf("clean system: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -127,7 +140,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("remove key: %w", err)
|
||||
}
|
||||
|
||||
delete(s.createdKeys, key)
|
||||
@@ -138,7 +151,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
@@ -151,7 +164,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po
|
||||
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error {
|
||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
@@ -176,33 +189,37 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
||||
|
||||
_, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err)
|
||||
return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||
primaryServiceKey, existingNameserver := s.getPrimaryService()
|
||||
if primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key")
|
||||
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||
}
|
||||
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
|
||||
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
|
||||
s.primaryServiceID = primaryServiceKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
line := buildCommandLine("show", globalIPv4State, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
|
||||
b, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
log.Error("got error while sending the command: ", err)
|
||||
return "", ""
|
||||
return "", "", fmt.Errorf("sending the command: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||
primaryService := ""
|
||||
router := ""
|
||||
@@ -215,7 +232,11 @@ func (s *systemConfigurator) getPrimaryService() (string, string) {
|
||||
router = strings.TrimSpace(strings.Split(text, ":")[1])
|
||||
}
|
||||
}
|
||||
return primaryService, router
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return primaryService, router, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return primaryService, router, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||
@@ -226,7 +247,14 @@ func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, e
|
||||
stdinCommands := wrapCommand(addDomainCommand)
|
||||
_, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while applying dns setup, error: %s", err)
|
||||
return fmt.Errorf("applying dns setup, error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -264,7 +292,7 @@ func runSystemConfigCommand(command string) ([]byte, error) {
|
||||
cmd.Stdin = strings.NewReader(command)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err)
|
||||
return nil, fmt.Errorf("running system configuration command: \"%s\", error: %w", command, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
43
client/internal/dns/host_ios.go
Normal file
43
client/internal/dns/host_ios.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type iosHostManager struct {
|
||||
dnsManager IosDnsManager
|
||||
config HostDNSConfig
|
||||
}
|
||||
|
||||
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
||||
return &iosHostManager{
|
||||
dnsManager: dnsManager,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
||||
jsonData, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
jsonString := string(jsonData)
|
||||
log.Debugf("Applying DNS settings: %s", jsonString)
|
||||
a.dnsManager.ApplyDns(jsonString)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
@@ -4,17 +4,15 @@ package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
)
|
||||
|
||||
const (
|
||||
netbirdManager osManagerType = iota
|
||||
fileManager
|
||||
@@ -23,15 +21,55 @@ const (
|
||||
resolvConfManager
|
||||
)
|
||||
|
||||
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
|
||||
|
||||
type osManagerType int
|
||||
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
func newOsManagerType(osManager string) (osManagerType, error) {
|
||||
switch osManager {
|
||||
case "netbird":
|
||||
return fileManager, nil
|
||||
case "file":
|
||||
return netbirdManager, nil
|
||||
case "networkManager":
|
||||
return networkManager, nil
|
||||
case "systemd":
|
||||
return systemdManager, nil
|
||||
case "resolvconf":
|
||||
return resolvConfManager, nil
|
||||
default:
|
||||
return 0, ErrUnknownOsManagerType
|
||||
}
|
||||
}
|
||||
|
||||
func (t osManagerType) String() string {
|
||||
switch t {
|
||||
case netbirdManager:
|
||||
return "netbird"
|
||||
case fileManager:
|
||||
return "file"
|
||||
case networkManager:
|
||||
return "networkManager"
|
||||
case systemdManager:
|
||||
return "systemd"
|
||||
case resolvConfManager:
|
||||
return "resolvconf"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("discovered mode is: %d", osManager)
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
return newHostManagerFromType(wgInterface, osManager)
|
||||
}
|
||||
|
||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
|
||||
switch osManager {
|
||||
case networkManager:
|
||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||
@@ -45,12 +83,15 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err)
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
@@ -65,11 +106,14 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
return systemdManager, nil
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
@@ -85,5 +129,26 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return fileManager, nil
|
||||
}
|
||||
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
for _, ns := range rConf.nameServers {
|
||||
if ns == "127.0.0.53" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -9,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match"
|
||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||
dnsPolicyConfigVersionKey = "Version"
|
||||
dnsPolicyConfigVersionValue = 2
|
||||
dnsPolicyConfigNameKey = "Name"
|
||||
@@ -19,16 +21,14 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
existingSearchDomains []string
|
||||
guid string
|
||||
routingAll bool
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
@@ -36,29 +36,38 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newHostManagerWithGuid(guid)
|
||||
}
|
||||
|
||||
func newHostManagerWithGuid(guid string) (hostManager, error) {
|
||||
return ®istryConfigurator{
|
||||
guid: guid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *registryConfigurator) supportCustomPort() bool {
|
||||
func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
if config.routeAll {
|
||||
err = r.addDNSSetupForAll(config.serverIP)
|
||||
if config.RouteAll {
|
||||
err = r.addDNSSetupForAll(config.ServerIP)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP)
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(r.guid); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -66,28 +75,28 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if !dConf.matchOnly {
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
if !dConf.MatchOnly {
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
matchDomains = append(matchDomains, "."+dConf.domain)
|
||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
err = r.addDNSMatchPolicy(matchDomains, config.serverIP)
|
||||
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
} else {
|
||||
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
|
||||
err = r.updateSearchDomains(searchDomains)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -96,7 +105,7 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed with error: %s", err)
|
||||
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
|
||||
}
|
||||
r.routingAll = true
|
||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
||||
@@ -108,33 +117,33 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
||||
if err == nil {
|
||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
||||
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
|
||||
}
|
||||
|
||||
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
|
||||
}
|
||||
|
||||
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err)
|
||||
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
||||
@@ -143,37 +152,25 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreHostDNS() error {
|
||||
err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
log.Errorf("remove registry key from dns policy config: %s", err)
|
||||
}
|
||||
|
||||
return r.updateSearchDomains([]string{})
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
|
||||
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
|
||||
}
|
||||
|
||||
valueList := strings.Split(value, ",")
|
||||
setExisting := false
|
||||
if len(r.existingSearchDomains) == 0 {
|
||||
r.existingSearchDomains = valueList
|
||||
setExisting = true
|
||||
}
|
||||
|
||||
if len(domains) == 0 && setExisting {
|
||||
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
newList := append(r.existingSearchDomains, domains...)
|
||||
|
||||
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding search domain failed with error: %s", err)
|
||||
return fmt.Errorf("adding search domain failed with error: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
||||
@@ -184,13 +181,13 @@ func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
|
||||
regKey, err := r.getInterfaceRegistryKey()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
defer closer(regKey)
|
||||
|
||||
err = regKey.SetStringValue(key, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err)
|
||||
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -199,13 +196,13 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
|
||||
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
|
||||
regKey, err := r.getInterfaceRegistryKey()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get interface registry key: %w", err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
defer closer(regKey)
|
||||
|
||||
err = regKey.DeleteValue(propertyKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err)
|
||||
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -218,50 +215,33 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
||||
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
||||
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||
}
|
||||
|
||||
return regKey, nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
k.Close()
|
||||
defer closer(k)
|
||||
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
||||
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
||||
func closer(closer io.Closer) {
|
||||
if err := closer.Close(); err != nil {
|
||||
log.Errorf("failed to close: %s", err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
|
||||
val, _, err := regKey.GetStringValue(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
|
||||
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
||||
}
|
||||
defer regKey.Close()
|
||||
|
||||
err = regKey.SetStringValue(key, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
fullRecord, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
fullRecord.Header().Rdlength = record.Len()
|
||||
@@ -71,3 +71,5 @@ func buildRecordKey(name string, class, qType uint16) string {
|
||||
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
|
||||
return key
|
||||
}
|
||||
|
||||
func (d *localResolver) probeAvailability() {}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@@ -32,7 +33,7 @@ func (m *MockServer) DnsIP() string {
|
||||
}
|
||||
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||
//TODO implement me
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
@@ -43,3 +44,11 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
}
|
||||
return fmt.Errorf("method UpdateDNSServer is not implemented")
|
||||
}
|
||||
|
||||
func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
func (m *MockServer) ProbeAvailability() {
|
||||
}
|
||||
@@ -5,15 +5,18 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,9 +43,13 @@ const (
|
||||
networkManagerDbusPrimaryDNSPriority int32 = -500
|
||||
networkManagerDbusWithMatchDomainPriority int32 = 0
|
||||
networkManagerDbusSearchDomainOnlyPriority int32 = 50
|
||||
supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28"
|
||||
)
|
||||
|
||||
var supportedNetworkManagerVersionConstraints = []string{
|
||||
">= 1.16, < 1.27",
|
||||
">= 1.44, < 1.45",
|
||||
}
|
||||
|
||||
type networkManagerDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
@@ -70,19 +77,19 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get nm dbus: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var s string
|
||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s)
|
||||
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface).Store(&s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("call: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name())
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
@@ -93,17 +100,17 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
|
||||
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
||||
}
|
||||
|
||||
connSettings.cleanDeprecatedSettings()
|
||||
|
||||
dnsIP, err := netip.ParseAddr(config.serverIP)
|
||||
dnsIP, err := netip.ParseAddr(config.ServerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse ip address, error: %s", err)
|
||||
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||
}
|
||||
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
|
||||
@@ -111,56 +118,70 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain))
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...)
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
||||
priority := networkManagerDbusSearchDomainOnlyPriority
|
||||
switch {
|
||||
case config.routeAll:
|
||||
case config.RouteAll:
|
||||
priority = networkManagerDbusPrimaryDNSPriority
|
||||
newDomainList = append(newDomainList, "~.")
|
||||
if !n.routingAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
case len(matchDomains) > 0:
|
||||
priority = networkManagerDbusWithMatchDomainPriority
|
||||
}
|
||||
|
||||
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort)
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
n.routingAll = false
|
||||
}
|
||||
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||
|
||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
||||
// The file content itself is not important for network-manager restoration
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||
err = n.reApplyConnectionSettings(connSettings, configVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err)
|
||||
return fmt.Errorf("reapplying the connection with new settings, error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
||||
// once the interface is gone network manager cleans all config associated with it
|
||||
return n.deleteConnectionSettings()
|
||||
if err := n.deleteConnectionSettings(); err != nil {
|
||||
return fmt.Errorf("delete connection settings: %w", err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
||||
return nil, 0, fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
@@ -175,7 +196,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
|
||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag,
|
||||
networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err)
|
||||
return nil, 0, fmt.Errorf("calling GetAppliedConnection method with context, err: %w", err)
|
||||
}
|
||||
|
||||
return connSettings, configVersion, nil
|
||||
@@ -184,7 +205,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
|
||||
func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
@@ -194,7 +215,7 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
|
||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag,
|
||||
connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while calling ReApply method with context, err: %s", err)
|
||||
return fmt.Errorf("calling ReApply method with context, err: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -203,21 +224,34 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
|
||||
func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// this call is required to remove the device for DNS cleanup, even if it fails
|
||||
err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while calling delete method with context, err: %s", err)
|
||||
var dbusErr dbus.Error
|
||||
if errors.As(err, &dbusErr) && dbusErr.Name == dbus.ErrMsgUnknownMethod.Name {
|
||||
// interface is gone already
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("calling delete method with context, err: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := n.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNetworkManagerSupported() bool {
|
||||
return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode()
|
||||
}
|
||||
@@ -249,13 +283,13 @@ func isNetworkManagerSupportedMode() bool {
|
||||
func getNetworkManagerDNSProperty(property string, store any) error {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err)
|
||||
return fmt.Errorf("attempting to retrieve the network manager dns manager object, error: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
v, e := obj.GetProperty(property)
|
||||
if e != nil {
|
||||
return fmt.Errorf("got an error getting property %s: %v", property, e)
|
||||
return fmt.Errorf("getting property %s: %w", property, e)
|
||||
}
|
||||
|
||||
return v.Store(store)
|
||||
@@ -277,24 +311,30 @@ func isNetworkManagerSupportedVersion() bool {
|
||||
}
|
||||
versionValue, err := parseVersion(value.Value().(string))
|
||||
if err != nil {
|
||||
log.Errorf("nm: parse version: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint)
|
||||
if err != nil {
|
||||
return false
|
||||
var supported bool
|
||||
for _, constraint := range supportedNetworkManagerVersionConstraints {
|
||||
constr, err := version.NewConstraint(constraint)
|
||||
if err != nil {
|
||||
log.Errorf("nm: create constraint: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if met := constr.Check(versionValue); met {
|
||||
supported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return constraints.Check(versionValue)
|
||||
log.Debugf("network manager constraints [%s] met: %t", strings.Join(supportedNetworkManagerVersionConstraints, " | "), supported)
|
||||
return supported
|
||||
}
|
||||
|
||||
func parseVersion(inputVersion string) (*version.Version, error) {
|
||||
reg, err := regexp.Compile(version.SemverRegexpRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if inputVersion == "" || !reg.MatchString(inputVersion) {
|
||||
if inputVersion == "" || !nbversion.SemverRegexp.MatchString(inputVersion) {
|
||||
return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer")
|
||||
}
|
||||
|
||||
|
||||
57
client/internal/dns/notifier.go
Normal file
57
client/internal/dns/notifier.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
searchDomains []string
|
||||
}
|
||||
|
||||
func newNotifier(initialSearchDomains []string) *notifier {
|
||||
sort.Strings(initialSearchDomains)
|
||||
return ¬ifier{
|
||||
searchDomains: initialSearchDomains,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *notifier) onNewSearchDomains(searchDomains []string) {
|
||||
sort.Strings(searchDomains)
|
||||
|
||||
if len(n.searchDomains) != len(searchDomains) {
|
||||
n.searchDomains = searchDomains
|
||||
n.notify()
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(n.searchDomains, searchDomains) {
|
||||
return
|
||||
}
|
||||
|
||||
n.searchDomains = searchDomains
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(l listener.NetworkChangeListener) {
|
||||
l.OnNetworkChanged("")
|
||||
}(n.listener)
|
||||
}
|
||||
@@ -3,10 +3,10 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -15,11 +15,24 @@ const resolvconfCommand = "resolvconf"
|
||||
|
||||
type resolvconf struct {
|
||||
ifaceName string
|
||||
|
||||
originalSearchDomains []string
|
||||
originalNameServers []string
|
||||
othersConfigs []string
|
||||
}
|
||||
|
||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
// supported "openresolv" only
|
||||
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
|
||||
resolvConfEntries, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface.Name(),
|
||||
ifaceName: wgInterface,
|
||||
originalSearchDomains: resolvConfEntries.searchDomains,
|
||||
originalNameServers: resolvConfEntries.nameServers,
|
||||
othersConfigs: resolvConfEntries.others,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -27,69 +40,69 @@ func (r *resolvconf) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
if !config.routeAll {
|
||||
if !config.RouteAll {
|
||||
err = r.restoreHostDNS()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("restore host dns: %s", err)
|
||||
}
|
||||
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
|
||||
}
|
||||
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
searchDomainList := searchDomains(config)
|
||||
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
|
||||
|
||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||
// lets log all skipped domains
|
||||
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
|
||||
continue
|
||||
}
|
||||
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
||||
|
||||
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
|
||||
// lets log all skipped domains
|
||||
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
|
||||
continue
|
||||
}
|
||||
buf := prepareResolvConfContent(
|
||||
searchDomainList,
|
||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||
options)
|
||||
|
||||
searchDomains += " " + dConf.domain
|
||||
appendedDomains++
|
||||
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
|
||||
err = r.applyConfig(buf)
|
||||
if err != nil {
|
||||
log.Errorf("Could not read existing resolv.conf")
|
||||
}
|
||||
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
|
||||
|
||||
err = r.applyConfig(content)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("apply config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains. Search list: %s", appendedDomains, searchDomains)
|
||||
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) restoreHostDNS() error {
|
||||
// openresolv only, debian resolvconf doesn't support "-f"
|
||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while removing resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
||||
// openresolv only, debian resolvconf doesn't support "-x"
|
||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
cmd.Stdin = &content
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolvconf) applyConfig(content string) error {
|
||||
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
|
||||
cmd.Stdin = strings.NewReader(content)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("got an error while appying resolvconf configuration for %s interface, error: %s", r.ifaceName, err)
|
||||
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := r.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -31,10 +31,13 @@ func (r *responseWriter) RemoteAddr() net.Addr {
|
||||
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
buff, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
_, err = r.Write(buff)
|
||||
return err
|
||||
|
||||
if _, err := r.Write(buff); err != nil {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes a raw buffer back to the client.
|
||||
|
||||
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@@ -18,6 +21,11 @@ type ReadyListener interface {
|
||||
OnReady()
|
||||
}
|
||||
|
||||
// IosDnsManager is a dns manager interface for iOS
|
||||
type IosDnsManager interface {
|
||||
ApplyDns(string)
|
||||
}
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
Initialize() error
|
||||
@@ -25,6 +33,8 @@ type Server interface {
|
||||
DnsIP() string
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@@ -41,17 +51,24 @@ type DefaultServer struct {
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
currentConfig HostDNSConfig
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
hostsDnsList []string
|
||||
hostsDnsListLock sync.Mutex
|
||||
|
||||
// make sense on mobile only
|
||||
searchDomainNotifier *notifier
|
||||
iosDnsManager IosDnsManager
|
||||
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
stop()
|
||||
probeAvailability()
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
@@ -60,7 +77,12 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -77,21 +99,43 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
||||
func NewDefaultServerPermanentUpstream(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
hostsDnsList []string,
|
||||
config nbdns.Config,
|
||||
listener listener.NetworkChangeListener,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.permanent = true
|
||||
ds.hostsDnsList = hostsDnsList
|
||||
ds.addHostRootZone()
|
||||
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
|
||||
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
|
||||
ds.searchDomainNotifier.setListener(listener)
|
||||
setServerDns(ds)
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||
// NewDefaultServerIos returns a new dns server. It optimized for ios
|
||||
func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
@@ -101,7 +145,8 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
|
||||
return defaultServer
|
||||
@@ -119,12 +164,15 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
if s.permanent {
|
||||
err = s.service.Listen()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("service listen: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.hostManager, err = newHostManager(s.wgInterface)
|
||||
return
|
||||
s.hostManager, err = s.initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DnsIP returns the DNS resolver server IP address
|
||||
@@ -202,7 +250,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("apply configuration: %w", err)
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
@@ -212,8 +260,31 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) SearchDomains() []string {
|
||||
var searchDomains []string
|
||||
|
||||
for _, dConf := range s.currentConfig.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
return searchDomains
|
||||
}
|
||||
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
mux.probeAvailability()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// is the service should be Disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
@@ -229,7 +300,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
@@ -238,14 +309,20 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
hostUpdate := s.currentConfig
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||
hostUpdate.routeAll = false
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.RouteAll = false
|
||||
}
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
}
|
||||
|
||||
s.updateNSGroupStates(update.NameServerGroups)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -285,10 +362,19 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
continue
|
||||
}
|
||||
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
@@ -306,7 +392,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||
// because it is temporal deactivation until next try
|
||||
//
|
||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
@@ -335,6 +421,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
@@ -403,14 +490,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
) (deactivate func(error), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
deactivate = func(err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
l.Info("Temporarily deactivating nameservers group due to timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
@@ -418,29 +505,32 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.routeAll = false
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.service.DeregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.service.DeregisterMux(item.Domain)
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, err, false)
|
||||
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.service.RegisterMux(domain, handler)
|
||||
}
|
||||
|
||||
@@ -448,22 +538,88 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Debug("reactivate temporary disabled nameserver group")
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.routeAll = true
|
||||
s.currentConfig.RouteAll = true
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) addHostRootZone() {
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
s.wgInterface.Address().IP,
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return
|
||||
}
|
||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||
for n, ua := range s.hostsDnsList {
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
||||
a, err := netip.ParseAddr(ua)
|
||||
if err != nil {
|
||||
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
|
||||
continue
|
||||
}
|
||||
|
||||
ipString := ua
|
||||
if !a.Is4() {
|
||||
ipString = fmt.Sprintf("[%s]", ua)
|
||||
}
|
||||
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
|
||||
}
|
||||
handler.deactivate = func() {}
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
var states []peer.NSGroupState
|
||||
|
||||
for _, group := range groups {
|
||||
var servers []string
|
||||
for _, ns := range group.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
|
||||
state := peer.NSGroupState{
|
||||
ID: generateGroupKey(group),
|
||||
Servers: servers,
|
||||
Domains: group.Domains,
|
||||
// The probe will determine the state, default enabled
|
||||
Enabled: true,
|
||||
Error: nil,
|
||||
}
|
||||
states = append(states, state)
|
||||
}
|
||||
s.statusRecorder.UpdateDNSStates(states)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
|
||||
states := s.statusRecorder.GetDNSStates()
|
||||
id := generateGroupKey(nsGroup)
|
||||
for i, state := range states {
|
||||
if state.ID == id {
|
||||
states[i].Enabled = enabled
|
||||
states[i].Error = err
|
||||
break
|
||||
}
|
||||
}
|
||||
s.statusRecorder.UpdateDNSStates(states)
|
||||
}
|
||||
|
||||
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
var servers []string
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
||||
}
|
||||
|
||||
5
client/internal/dns/server_android.go
Normal file
5
client/internal/dns/server_android.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
}
|
||||
7
client/internal/dns/server_darwin.go
Normal file
7
client/internal/dns/server_darwin.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
}
|
||||
@@ -19,6 +19,6 @@ func TestGetServerDns(t *testing.T) {
|
||||
}
|
||||
|
||||
if srvB != srv {
|
||||
t.Errorf("missmatch dns instances")
|
||||
t.Errorf("mismatch dns instances")
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user