mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 14:34:54 -04:00
Compare commits
520 Commits
netmap
...
refactor/o
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0aeac50803 | ||
|
|
f1ee844446 | ||
|
|
5acfe00de1 | ||
|
|
3e836db1d1 | ||
|
|
232c9788b9 | ||
|
|
c01874e9ce | ||
|
|
1b2517ea20 | ||
|
|
93c8f9476c | ||
|
|
c6d16724fb | ||
|
|
3556d6a1fc | ||
|
|
96e9c3683c | ||
|
|
a4a5a2eb23 | ||
|
|
a78a313fcf | ||
|
|
aa34dda05a | ||
|
|
eeb688cb5e | ||
|
|
f78b629f40 | ||
|
|
5c8be070f0 | ||
|
|
833fec0029 | ||
|
|
66f003d00f | ||
|
|
14a3461e86 | ||
|
|
b853c87e36 | ||
|
|
5d3832c597 | ||
|
|
1fd6583e5f | ||
|
|
3e9f0d57ac | ||
|
|
d0ab68c599 | ||
|
|
d45255b1b7 | ||
|
|
2af95886f8 | ||
|
|
97ff9a3e5d | ||
|
|
b6a558b523 | ||
|
|
f597379be5 | ||
|
|
265ded0508 | ||
|
|
d962f7fd27 | ||
|
|
30ae3aa030 | ||
|
|
481bbe8513 | ||
|
|
bc7b2c6ba3 | ||
|
|
c6f7a299a9 | ||
|
|
992a6c79b4 | ||
|
|
78795a4a73 | ||
|
|
5a82477d48 | ||
|
|
1ffa519387 | ||
|
|
e4a25b6a60 | ||
|
|
6a6b527f24 | ||
|
|
aeb6e53e51 | ||
|
|
b34887a920 | ||
|
|
b9efda3ce8 | ||
|
|
516de93627 | ||
|
|
4afaabb33c | ||
|
|
6367fe4f74 | ||
|
|
22798ada25 | ||
|
|
2fb399aa7d | ||
|
|
6a81ca2e52 | ||
|
|
e27db948ae | ||
|
|
167c80da40 | ||
|
|
7241a16ff7 | ||
|
|
30b023d126 | ||
|
|
7baef8c502 | ||
|
|
b15ee5c07c | ||
|
|
b1e8ed889e | ||
|
|
44f69c70e0 | ||
|
|
a5731fe509 | ||
|
|
61b38e56e4 | ||
|
|
29ea44b874 | ||
|
|
84aea32118 | ||
|
|
acb5340d40 | ||
|
|
8b7766e34d | ||
|
|
2706ede08e | ||
|
|
3483139903 | ||
|
|
ce7385119a | ||
|
|
cd15c85d5b | ||
|
|
9ff56eae64 | ||
|
|
47a18db186 | ||
|
|
aa0480c5e6 | ||
|
|
15f0a665f8 | ||
|
|
9b5b632ff9 | ||
|
|
0c28099712 | ||
|
|
eb062c07ec | ||
|
|
7a9c75db91 | ||
|
|
c603c40a53 | ||
|
|
48af90c770 | ||
|
|
3cc6d3862d | ||
|
|
d1e5d584f7 | ||
|
|
b1325267d8 | ||
|
|
522dd44bfa | ||
|
|
8154069e77 | ||
|
|
e161a92898 | ||
|
|
3fce8485bb | ||
|
|
1cc88a2190 | ||
|
|
168ea9560e | ||
|
|
f48e33b395 | ||
|
|
f1ed8599fc | ||
|
|
2e596fbf1a | ||
|
|
93f3e1b14b | ||
|
|
649bfb236b | ||
|
|
fa1eaa0aec | ||
|
|
baf211203a | ||
|
|
9d86f76a24 | ||
|
|
409003b4f9 | ||
|
|
9e6e34b42d | ||
|
|
d9905d1a57 | ||
|
|
2bd68efc08 | ||
|
|
6848e1e128 | ||
|
|
668aead4c8 | ||
|
|
f08605a7f1 | ||
|
|
02a3feddb8 | ||
|
|
d9487a5749 | ||
|
|
cfa6d09c5e | ||
|
|
a01253c3c8 | ||
|
|
bc013e4888 | ||
|
|
782e3f8853 | ||
|
|
03fd656344 | ||
|
|
18b049cd24 | ||
|
|
2bdb4cb44a | ||
|
|
abbdf20f65 | ||
|
|
43ef64cf67 | ||
|
|
18316be09a | ||
|
|
1a623943c8 | ||
|
|
fbce8bb511 | ||
|
|
445b626dc8 | ||
|
|
b3c87cb5d1 | ||
|
|
0dbaddc7be | ||
|
|
ad9f044aad | ||
|
|
05930ee6b1 | ||
|
|
e670068cab | ||
|
|
b48cf1bf65 | ||
|
|
7ee7ada273 | ||
|
|
82b4e58ad0 | ||
|
|
ddc365f7a0 | ||
|
|
37ad370344 | ||
|
|
703647da1e | ||
|
|
9eff58ae62 | ||
|
|
3844516aa7 | ||
|
|
f591e47404 | ||
|
|
287ae81195 | ||
|
|
a4a30744ad | ||
|
|
dcba6a6b7e | ||
|
|
6142828a9c | ||
|
|
97bb74f824 | ||
|
|
2147bf75eb | ||
|
|
9a96b91d9d | ||
|
|
e40a29ba17 | ||
|
|
a05bd464cd | ||
|
|
ff330e644e | ||
|
|
713e320c4c | ||
|
|
e67fe89adb | ||
|
|
6cfbb1f320 | ||
|
|
c853011a32 | ||
|
|
b50b89ba14 | ||
|
|
d063fbb8b9 | ||
|
|
e5d42bc963 | ||
|
|
8866394eb6 | ||
|
|
17c20b45ce | ||
|
|
7dacd9cb23 | ||
|
|
6285e0d23e | ||
|
|
a4826cfb5f | ||
|
|
a0bf0bdcc0 | ||
|
|
dffce78a8c | ||
|
|
c7e7ad5030 | ||
|
|
5142dc52c1 | ||
|
|
ecb44ff306 | ||
|
|
e4a5fb3e91 | ||
|
|
e52d352a48 | ||
|
|
f9723c9266 | ||
|
|
8efad1d170 | ||
|
|
a3663fb444 | ||
|
|
c6641be94b | ||
|
|
8c4613b456 | ||
|
|
89cf8a55e2 | ||
|
|
d66140fc82 | ||
|
|
dea6886394 | ||
|
|
1ba6eb62a6 | ||
|
|
f87bc601c6 | ||
|
|
00c3b67182 | ||
|
|
cde0e51c72 | ||
|
|
a22d5041e3 | ||
|
|
fde9f2ffda | ||
|
|
21561a2b07 | ||
|
|
b2139875d9 | ||
|
|
9203690033 | ||
|
|
9683da54b0 | ||
|
|
2e6bbaca96 | ||
|
|
0e48a772ff | ||
|
|
f118d81d32 | ||
|
|
72bfc9d07e | ||
|
|
79822cdc15 | ||
|
|
bdb2a76eae | ||
|
|
ca12bc6953 | ||
|
|
9810386937 | ||
|
|
f1625b32bd | ||
|
|
0ecd5f2118 | ||
|
|
940d0c48c6 | ||
|
|
56cecf849e | ||
|
|
05c4aa7c2c | ||
|
|
2a5cb16494 | ||
|
|
9db1932664 | ||
|
|
1bbabf70b0 | ||
|
|
82746d93ee | ||
|
|
aa575d6f44 | ||
|
|
f66bbcc54c | ||
|
|
5dd6a08ea6 | ||
|
|
eb5d0569ae | ||
|
|
0ee56e14d9 | ||
|
|
52ea2e84e9 | ||
|
|
20fc8e879e | ||
|
|
b60e2c3261 | ||
|
|
df98c67ac8 | ||
|
|
78fab877c0 | ||
|
|
ec6438e643 | ||
|
|
6dd56e3328 | ||
|
|
48edfa601f | ||
|
|
a2a49bdd47 | ||
|
|
a2fb274b86 | ||
|
|
a61e9da3e9 | ||
|
|
65a94f695f | ||
|
|
ec543f89fb | ||
|
|
a7d5c52203 | ||
|
|
582bb58714 | ||
|
|
121dfda915 | ||
|
|
d4c712493a | ||
|
|
1ff8f616a8 | ||
|
|
a1c5287b7c | ||
|
|
12f442439a | ||
|
|
51c1ec283c | ||
|
|
d9b691b8a5 | ||
|
|
4aee3c9e33 | ||
|
|
4ef3890bf7 | ||
|
|
92b9e11d3f | ||
|
|
44e799c687 | ||
|
|
be78efbd42 | ||
|
|
f6f7260897 | ||
|
|
c557c98390 | ||
|
|
7d849a92c0 | ||
|
|
f5e7449d01 | ||
|
|
8420a52563 | ||
|
|
6315644065 | ||
|
|
ef55b9eccc | ||
|
|
218345e0ff | ||
|
|
6886691213 | ||
|
|
b48afd92fd | ||
|
|
a4d905ffe7 | ||
|
|
ed047ec9dd | ||
|
|
39329e12a1 | ||
|
|
4b943c34b7 | ||
|
|
560190519d | ||
|
|
9bc8e6e29e | ||
|
|
9872bee41d | ||
|
|
3a915decd7 | ||
|
|
50e6389a1d | ||
|
|
bbaee18cd5 | ||
|
|
32d1b2d602 | ||
|
|
2a59f04540 | ||
|
|
446de5e2bc | ||
|
|
147971fdfe | ||
|
|
ed259a6a03 | ||
|
|
a3abc211b3 | ||
|
|
20a5afc359 | ||
|
|
00023bf110 | ||
|
|
2806d73161 | ||
|
|
2d7f08c609 | ||
|
|
0c0fd380bd | ||
|
|
ffce48ca5f | ||
|
|
d23b5c892b | ||
|
|
113c21b0e1 | ||
|
|
ab00c41dad | ||
|
|
664d1388aa | ||
|
|
010a8bfdc1 | ||
|
|
6cb697eed6 | ||
|
|
e0bed2b0fb | ||
|
|
601d429d82 | ||
|
|
30f025e7dd | ||
|
|
b4d7605147 | ||
|
|
d54b6967ce | ||
|
|
174e07fefd | ||
|
|
871500c5cc | ||
|
|
cc04aef7b4 | ||
|
|
3ed8b9cee9 | ||
|
|
08b6e9d647 | ||
|
|
bdeb95c58c | ||
|
|
6dc185e141 | ||
|
|
7100be83cd | ||
|
|
67ce14eaea | ||
|
|
d58cf50127 | ||
|
|
40af1a50e3 | ||
|
|
ac05f69131 | ||
|
|
8126d95316 | ||
|
|
0a70e4c5d4 | ||
|
|
106fc75936 | ||
|
|
669904cd06 | ||
|
|
f8b5eedd38 | ||
|
|
931521d505 | ||
|
|
1a5f3c653c | ||
|
|
78044c226d | ||
|
|
389c9619af | ||
|
|
4be826450b | ||
|
|
738387f2de | ||
|
|
baf0678ceb | ||
|
|
7fef8f6758 | ||
|
|
6829a64a2d | ||
|
|
cbf500024f | ||
|
|
509e184e10 | ||
|
|
3e88b7c56e | ||
|
|
b952d8693d | ||
|
|
5b46cc8e9c | ||
|
|
a9d06b883f | ||
|
|
5f06b202c3 | ||
|
|
0eb99c266a | ||
|
|
bac95ace18 | ||
|
|
9812de853b | ||
|
|
ad4f0a6fdf | ||
|
|
4c758c6e52 | ||
|
|
ec5095ba6b | ||
|
|
49a54624f8 | ||
|
|
729bcf2b01 | ||
|
|
a0cdb58303 | ||
|
|
39c99781cb | ||
|
|
01f24907c5 | ||
|
|
10480eb52f | ||
|
|
1e44c5b574 | ||
|
|
940f8b4547 | ||
|
|
46e37fa04c | ||
|
|
b9f205b2ce | ||
|
|
0fd874fa45 | ||
|
|
8016710d24 | ||
|
|
4e918e55ba | ||
|
|
869537c951 | ||
|
|
44f2ce666e | ||
|
|
563dca705c | ||
|
|
7bda385e1b | ||
|
|
30ebcf38c7 | ||
|
|
0106a95f7a | ||
|
|
9929b22afc | ||
|
|
88e4fc2245 | ||
|
|
c8d8748dcf | ||
|
|
507a40bd7f | ||
|
|
ccd4ae6315 | ||
|
|
96d2207684 | ||
|
|
f942491b91 | ||
|
|
8c8900be57 | ||
|
|
cee95461d1 | ||
|
|
49e65109d2 | ||
|
|
d93dd4fc7f | ||
|
|
3a88ac78ff | ||
|
|
da3a053e2b | ||
|
|
0e95f16cdd | ||
|
|
b2379175fe | ||
|
|
09bdd271f1 | ||
|
|
208a2b7169 | ||
|
|
8284ae959c | ||
|
|
6ce09bca16 | ||
|
|
b79c1d64cc | ||
|
|
b1eda43f4b | ||
|
|
d4ef84fe6e | ||
|
|
44e8107383 | ||
|
|
2c1f5e46d5 | ||
|
|
dbec24b520 | ||
|
|
f603cd9202 | ||
|
|
5897a48e29 | ||
|
|
8bf729c7b4 | ||
|
|
7f09b39769 | ||
|
|
158936fb15 | ||
|
|
8934453b30 | ||
|
|
fd67892cb4 | ||
|
|
7e5d3bdfe2 | ||
|
|
b7b0828133 | ||
|
|
ff7863785f | ||
|
|
a3a479429e | ||
|
|
5932298ce0 | ||
|
|
ee0ea86a0a | ||
|
|
24c0aaa745 | ||
|
|
16179db599 | ||
|
|
e27f85b317 | ||
|
|
2fd60b2cb4 | ||
|
|
3dca6099d4 | ||
|
|
cfbcf507fb | ||
|
|
52ae693c9e | ||
|
|
58ff7ab797 | ||
|
|
acb73bd64a | ||
|
|
4ebf6e1c4c | ||
|
|
1e4a0f77e2 | ||
|
|
b51d75204b | ||
|
|
e7d52c8c95 | ||
|
|
ab82302c95 | ||
|
|
d47be154ea | ||
|
|
35c892aea3 | ||
|
|
fc4b37f7bc | ||
|
|
6f0fd1d1b3 | ||
|
|
28cbb4b70f | ||
|
|
1104c9c048 | ||
|
|
5bc601111d | ||
|
|
b74951f29e | ||
|
|
97e10e440c | ||
|
|
6c50b0c84b | ||
|
|
730dd1733e | ||
|
|
82739e2832 | ||
|
|
fa7767e612 | ||
|
|
f1171198de | ||
|
|
9e041b7f82 | ||
|
|
b4c8cf0a67 | ||
|
|
1ef51a4ffa | ||
|
|
f6d57e7a96 | ||
|
|
ab892b8cf9 | ||
|
|
33c9b2d989 | ||
|
|
170e842422 | ||
|
|
4c130a0291 | ||
|
|
afb9673bc4 | ||
|
|
cf6210a6f4 | ||
|
|
c59a39d27d | ||
|
|
47adb976f8 | ||
|
|
9cfc8f8aa4 | ||
|
|
2d1bf3982d | ||
|
|
50ebbe482e | ||
|
|
f43a0a0177 | ||
|
|
51e1d3ab8f | ||
|
|
12c36312b5 | ||
|
|
c720d54de6 | ||
|
|
28248ea9f4 | ||
|
|
0c039274a4 | ||
|
|
fcac02a92f | ||
|
|
a7e46bf7b1 | ||
|
|
fcf150f704 | ||
|
|
a33b11946d | ||
|
|
bdbd1db843 | ||
|
|
f2b5b2e9b5 | ||
|
|
c52b406afa | ||
|
|
1ff7a953a0 | ||
|
|
13e923b7c6 | ||
|
|
13e7198046 | ||
|
|
95174d4619 | ||
|
|
92a0092ad5 | ||
|
|
5ac6f56594 | ||
|
|
880b81154f | ||
|
|
7efaf7eadb | ||
|
|
63a75d72fc | ||
|
|
00944bcdbf | ||
|
|
be6bc46bcd | ||
|
|
d97b03656f | ||
|
|
33b264e598 | ||
|
|
d92f2b633f | ||
|
|
ddea001170 | ||
|
|
5d6dfe5938 | ||
|
|
0f0415b92a | ||
|
|
3ed90728e6 | ||
|
|
8c2d37d3fc | ||
|
|
80b0db80bc | ||
|
|
2a30db02bb | ||
|
|
d2b04922e9 | ||
|
|
049b5fb7ed | ||
|
|
a6c59601f9 | ||
|
|
6016d2f7ce | ||
|
|
181dd93695 | ||
|
|
4bbedb5193 | ||
|
|
9716be854d | ||
|
|
539480a713 | ||
|
|
15eb752a7d | ||
|
|
af1b42e538 | ||
|
|
12f9d12a11 | ||
|
|
18cef8280a | ||
|
|
0911163146 | ||
|
|
bcce1bf184 | ||
|
|
ac0d5ff9f3 | ||
|
|
54d896846b | ||
|
|
855fba8fac | ||
|
|
1802e51213 | ||
|
|
d56dfae9b8 | ||
|
|
6b930271fd | ||
|
|
059fc7c3a2 | ||
|
|
0371f529ca | ||
|
|
501fd93e47 | ||
|
|
727a4f0753 | ||
|
|
e6f7222034 | ||
|
|
bfc33a3f6f | ||
|
|
5ad4ae769a | ||
|
|
f84b606506 | ||
|
|
216d9f2ee8 | ||
|
|
57624203c9 | ||
|
|
24e031ab74 | ||
|
|
df8b8db068 | ||
|
|
3506ac4234 | ||
|
|
0c8f8a62c7 | ||
|
|
cbf9f2058e | ||
|
|
02f3105e48 | ||
|
|
5ee9c77e90 | ||
|
|
c832cef44c | ||
|
|
165988429c | ||
|
|
9d2047a08a | ||
|
|
da39c8bbca | ||
|
|
7321046cd6 | ||
|
|
ea3205643a | ||
|
|
1a15b0f900 | ||
|
|
1f48fdf6ca | ||
|
|
45fd1e9c21 | ||
|
|
63aeeb834d | ||
|
|
268e801ec5 | ||
|
|
788f130941 | ||
|
|
926e11b086 | ||
|
|
0a8c78deb1 | ||
|
|
c815ad86fd | ||
|
|
ef1a39cb01 | ||
|
|
c900fa81bb | ||
|
|
9a6de52dd0 | ||
|
|
19147f518e | ||
|
|
e78ec2e985 | ||
|
|
95d725f2c1 | ||
|
|
4fad0e521f | ||
|
|
a711e116a3 | ||
|
|
668d229b67 | ||
|
|
7c595e8493 | ||
|
|
f9c59a7131 | ||
|
|
1d6f5482dd | ||
|
|
12ff93ba72 | ||
|
|
88d1c5a0fd | ||
|
|
1537b0f5e7 | ||
|
|
2577100096 | ||
|
|
bc09348f5a | ||
|
|
d5ba2ef6ec | ||
|
|
47752e1573 | ||
|
|
58fbc1249c | ||
|
|
1cc341a268 | ||
|
|
89df6e7242 | ||
|
|
f74646a3ac | ||
|
|
e8c2fafccd |
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.21-bullseye
|
||||
FROM golang:1.23-bullseye
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install --no-install-recommends\
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||
"ghcr.io/devcontainers/features/go:1": {
|
||||
"version": "1.21"
|
||||
"version": "1.23"
|
||||
}
|
||||
},
|
||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||
|
||||
8
.editorconfig
Normal file
8
.editorconfig
Normal file
@@ -0,0 +1,8 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
||||
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [netbirdio]
|
||||
9
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
9
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,9 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**NetBird status -d output:**
|
||||
**NetBird status -dA output:**
|
||||
|
||||
If applicable, add the `netbird status -d' command output.
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
|
||||
**Do you face any (non-mobile) client issues?**
|
||||
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
|
||||
15
.github/workflows/golang-test-darwin.yml
vendored
15
.github/workflows/golang-test-darwin.yml
vendored
@@ -18,18 +18,20 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
macos-gotest-
|
||||
macos-go-
|
||||
|
||||
- name: Install libpcap
|
||||
@@ -42,4 +44,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
|
||||
35
.github/workflows/golang-test-freebsd.yml
vendored
35
.github/workflows/golang-test-freebsd.yml
vendored
@@ -13,7 +13,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test in FreeBSD
|
||||
@@ -21,19 +21,26 @@ jobs:
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
prepare: |
|
||||
pkg install -y curl
|
||||
pkg install -y git
|
||||
pkg install -y go
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -x
|
||||
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
|
||||
tar zxf go.tar.gz
|
||||
mv go /usr/local/go
|
||||
ln -s /usr/local/go/bin/go /usr/local/bin/go
|
||||
go mod tidy
|
||||
go test -timeout 5m -p 1 ./iface/...
|
||||
go test -timeout 5m -p 1 ./client/...
|
||||
cd client
|
||||
go build .
|
||||
cd ..
|
||||
set -e -x
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./client/iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
||||
388
.github/workflows/golang-test-linux.yml
vendored
388
.github/workflows/golang-test-linux.yml
vendored
@@ -11,29 +11,114 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
build-cache:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
id: cache
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Build client
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: client
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build client 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: client
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
|
||||
|
||||
- name: Build management
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: management
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build management 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: management
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
|
||||
|
||||
- name: Build signal
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: signal
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build signal 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: signal
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
|
||||
|
||||
- name: Build relay
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 go build .
|
||||
|
||||
- name: Build relay 386
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
test:
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
@@ -49,26 +134,264 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
test_management:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
|
||||
benchmark:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
|
||||
api_benchmark:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
||||
|
||||
api_integration_test:
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
||||
|
||||
test_client_on_docker:
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
@@ -79,9 +402,6 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Iface Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
@@ -98,7 +418,7 @@ jobs:
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
||||
|
||||
- run: chmod +x *testing.bin
|
||||
|
||||
@@ -106,7 +426,7 @@ jobs:
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
@@ -124,4 +444,4 @@ jobs:
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
31
.github/workflows/golang-test-windows.yml
vendored
31
.github/workflows/golang-test-windows.yml
vendored
@@ -17,13 +17,30 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
@@ -42,11 +59,13 @@ jobs:
|
||||
- run: choco install -y sysinternals --ignore-checksums
|
||||
- run: choco install -y mingw
|
||||
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -15,11 +15,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
@@ -32,21 +32,21 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=12m
|
||||
args: --timeout=12m --out-format colored-line-number
|
||||
|
||||
3
.github/workflows/install-script-test.yml
vendored
3
.github/workflows/install-script-test.yml
vendored
@@ -13,6 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test-install-script:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest]
|
||||
@@ -21,7 +22,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -15,23 +15,23 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v3
|
||||
uses: actions/setup-java@v4
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -50,11 +50,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
@@ -62,4 +62,4 @@ jobs:
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
CGO_ENABLED: 0
|
||||
|
||||
178
.github/workflows/release.yml
vendored
178
.github/workflows/release.yml
vendored
@@ -3,15 +3,16 @@ name: Release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.11"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -19,26 +20,30 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -46,73 +51,57 @@ jobs:
|
||||
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-releaser-
|
||||
-
|
||||
name: Install modules
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
-
|
||||
name: check git status
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
-
|
||||
name: Set up QEMU
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
-
|
||||
name: Login to Docker hub
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: netbirdio
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc amd64
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows rsrc arm64
|
||||
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows rsrc arm
|
||||
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
|
||||
- name: Generate windows rsrc 386
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --rm-dist ${{ env.flags }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload linux packages
|
||||
uses: actions/upload-artifact@v3
|
||||
- name: upload linux packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload windows packages
|
||||
uses: actions/upload-artifact@v3
|
||||
- name: upload windows packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 3
|
||||
-
|
||||
name: upload macos packages
|
||||
uses: actions/upload-artifact@v3
|
||||
- name: upload macos packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -121,22 +110,29 @@ jobs:
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
|
||||
@@ -151,22 +147,23 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -177,20 +174,17 @@ jobs:
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
-
|
||||
name: Set up Go
|
||||
uses: actions/setup-go@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21"
|
||||
go-version: "1.23"
|
||||
cache: false
|
||||
-
|
||||
name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -198,53 +192,35 @@ jobs:
|
||||
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-ui-go-releaser-darwin-
|
||||
-
|
||||
name: Install modules
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
-
|
||||
name: check git status
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
-
|
||||
name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v3
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
retention-days: 3
|
||||
|
||||
trigger_windows_signer:
|
||||
trigger_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release,release_ui]
|
||||
needs: [release, release_ui, release_ui_darwin]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger Windows binaries sign pipeline
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: Sign windows bin and installer
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
trigger_darwin_signer:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [release,release_ui_darwin]
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger Darwin App binaries sign pipeline
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
with:
|
||||
workflow: Sign darwin ui app with dispatch
|
||||
repo: netbirdio/sign-pipelines
|
||||
ref: ${{ env.SIGN_PIPE_VER }}
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
|
||||
|
||||
118
.github/workflows/test-infrastructure-files.yml
vendored
118
.github/workflows/test-infrastructure-files.yml
vendored
@@ -18,7 +18,49 @@ concurrency:
|
||||
jobs:
|
||||
test-docker-compose:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
env:
|
||||
POSTGRES_USER: netbird
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: netbird
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
ports:
|
||||
- 5432:5432
|
||||
mysql:
|
||||
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
|
||||
env:
|
||||
MYSQL_USER: netbird
|
||||
MYSQL_PASSWORD: mysql
|
||||
MYSQL_ROOT_PASSWORD: mysqlroot
|
||||
MYSQL_DATABASE: netbird
|
||||
options: >-
|
||||
--health-cmd "mysqladmin ping --silent"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
ports:
|
||||
- 3306:3306
|
||||
steps:
|
||||
- name: Set Database Connection String
|
||||
run: |
|
||||
if [ "${{ matrix.store }}" == "postgres" ]; then
|
||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
|
||||
else
|
||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
||||
fi
|
||||
if [ "${{ matrix.store }}" == "mysql" ]; then
|
||||
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
|
||||
else
|
||||
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Install jq
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
@@ -26,12 +68,12 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
go-version: "1.23.x"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -39,7 +81,7 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: cp setup.env
|
||||
run: cp infrastructure_files/tests/setup.env infrastructure_files/
|
||||
@@ -58,7 +100,9 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
|
||||
- name: check values
|
||||
@@ -85,7 +129,9 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
|
||||
@@ -123,6 +169,15 @@ jobs:
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
# check relay values
|
||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||
grep '33445:33445' docker-compose.yml
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@@ -148,26 +203,35 @@ jobs:
|
||||
run: |
|
||||
docker build -t netbirdio/signal:latest .
|
||||
|
||||
- name: Build relay binary
|
||||
working-directory: relay
|
||||
run: CGO_ENABLED=0 go build -o netbird-relay main.go
|
||||
|
||||
- name: Build relay docker image
|
||||
working-directory: relay
|
||||
run: |
|
||||
docker build -t netbirdio/relay:latest .
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
sleep 5
|
||||
docker-compose ps
|
||||
docker-compose logs --tail=20
|
||||
docker compose ps
|
||||
docker compose logs --tail=20
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
|
||||
test $count -eq 4
|
||||
test $count -eq 5 || docker compose logs
|
||||
working-directory: infrastructure_files/artifacts
|
||||
|
||||
- name: test geolocation databases
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
sleep 30
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
|
||||
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
|
||||
|
||||
test-getting-started-script:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -176,7 +240,7 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
@@ -191,7 +255,7 @@ jobs:
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen postgres
|
||||
run: |
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
@@ -202,12 +266,15 @@ jobs:
|
||||
- name: test dashboard.env file gen postgres
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test relay.env file gen postgres
|
||||
run: test -f relay.env
|
||||
|
||||
- name: test zdb.env file gen postgres
|
||||
run: test -f zdb.env
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
docker compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
@@ -226,7 +293,7 @@ jobs:
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen CockroachDB
|
||||
run: |
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
@@ -237,20 +304,5 @@ jobs:
|
||||
- name: test dashboard.env file gen CockroachDB
|
||||
run: test -f dashboard.env
|
||||
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
- name: test relay.env file gen CockroachDB
|
||||
run: test -f relay.env
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,4 +29,3 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
GeoLite2-City*
|
||||
164
.goreleaser.yaml
164
.goreleaser.yaml
@@ -1,3 +1,5 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird
|
||||
@@ -22,7 +24,7 @@ builds:
|
||||
goarch: 386
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
@@ -42,19 +44,19 @@ builds:
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
- id: netbird-mgmt
|
||||
dir: management
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-mgmt
|
||||
goos:
|
||||
- linux
|
||||
@@ -64,7 +66,7 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-signal
|
||||
dir: signal
|
||||
@@ -78,7 +80,24 @@ builds:
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-relay
|
||||
dir: relay
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-relay
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
@@ -86,7 +105,6 @@ archives:
|
||||
- netbird-static
|
||||
|
||||
nfpms:
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
@@ -161,6 +179,97 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: client/Dockerfile-rootless
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: relay/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
ids:
|
||||
@@ -313,6 +422,30 @@ docker_manifests:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/relay:latest
|
||||
image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
@@ -344,10 +477,9 @@ docker_manifests:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
brews:
|
||||
-
|
||||
ids:
|
||||
- ids:
|
||||
- default
|
||||
tap:
|
||||
repository:
|
||||
owner: netbirdio
|
||||
name: homebrew-tap
|
||||
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
|
||||
@@ -364,7 +496,7 @@ brews:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-deb
|
||||
- netbird-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
@@ -386,4 +518,4 @@ checksum:
|
||||
release:
|
||||
extra_files:
|
||||
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
|
||||
- glob: ./release_files/install.sh
|
||||
- glob: ./release_files/install.sh
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui
|
||||
@@ -11,9 +13,7 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-ui-windows
|
||||
dir: client/ui
|
||||
@@ -28,7 +28,7 @@ builds:
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
- -H windowsgui
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
archives:
|
||||
- id: linux-arch
|
||||
@@ -41,7 +41,6 @@ archives:
|
||||
- netbird-ui-windows
|
||||
|
||||
nfpms:
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
@@ -79,7 +78,7 @@ nfpms:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-ui-deb
|
||||
- netbird-ui-deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
version: 2
|
||||
|
||||
project_name: netbird-ui
|
||||
builds:
|
||||
- id: netbird-ui-darwin
|
||||
@@ -17,10 +19,13 @@ builds:
|
||||
- softfloat
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
tags:
|
||||
- load_wgnt_from_rsrc
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird-ui-darwin
|
||||
|
||||
archives:
|
||||
- builds:
|
||||
- netbird-ui-darwin
|
||||
@@ -28,4 +33,4 @@ archives:
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
|
||||
changelog:
|
||||
skip: true
|
||||
disable: true
|
||||
|
||||
@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
|
||||
|
||||
**Goreleaser**
|
||||
```shell
|
||||
goreleaser --snapshot --rm-dist
|
||||
goreleaser build --snapshot --clean
|
||||
```
|
||||
**golangci-lint**
|
||||
```shell
|
||||
|
||||
22
README.md
22
README.md
@@ -1,22 +1,21 @@
|
||||
<p align="center">
|
||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
||||
Learn more
|
||||
</a>
|
||||
</p>
|
||||
<br/>
|
||||
<div align="center">
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
@@ -28,7 +27,7 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
@@ -47,6 +46,8 @@
|
||||
|
||||

|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
|
||||
### Key features
|
||||
|
||||
@@ -60,6 +61,7 @@
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM alpine:3.19
|
||||
FROM alpine:3.21.0
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
||||
16
client/Dockerfile-rootless
Normal file
16
client/Dockerfile-rootless
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM alpine:3.21.0
|
||||
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
&& adduser -D -h /var/lib/netbird netbird
|
||||
WORKDIR /var/lib/netbird
|
||||
USER netbird:netbird
|
||||
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENV NB_USE_NETSTACK_MODE=true
|
||||
ENV NB_CONFIG=config.json
|
||||
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||
ENV NB_DISABLE_DNS=true
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ type ConnectionListener interface {
|
||||
|
||||
// TunAdapter export internal TunAdapter for mobile
|
||||
type TunAdapter interface {
|
||||
iface.TunAdapter
|
||||
device.TunAdapter
|
||||
}
|
||||
|
||||
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||
@@ -51,7 +51,7 @@ func init() {
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
tunAdapter device.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
@@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
s, ok := gstatus.FromError(err)
|
||||
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const anonTLD = ".domain"
|
||||
|
||||
type Anonymizer struct {
|
||||
ipAnonymizer map[netip.Addr]netip.Addr
|
||||
domainAnonymizer map[string]string
|
||||
@@ -19,6 +21,8 @@ type Anonymizer struct {
|
||||
currentAnonIPv6 netip.Addr
|
||||
startAnonIPv4 netip.Addr
|
||||
startAnonIPv6 netip.Addr
|
||||
|
||||
domainKeyRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
@@ -34,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
currentAnonIPv6: startIPv6,
|
||||
startAnonIPv4: startIPv4,
|
||||
startAnonIPv6: startIPv6,
|
||||
|
||||
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,29 +89,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeDomain(domain string) string {
|
||||
if strings.HasSuffix(domain, "netbird.io") ||
|
||||
strings.HasSuffix(domain, "netbird.selfhosted") ||
|
||||
strings.HasSuffix(domain, "netbird.cloud") ||
|
||||
strings.HasSuffix(domain, "netbird.stage") ||
|
||||
strings.HasSuffix(domain, ".domain") {
|
||||
baseDomain := domain
|
||||
hasDot := strings.HasSuffix(domain, ".")
|
||||
if hasDot {
|
||||
baseDomain = domain[:len(domain)-1]
|
||||
}
|
||||
|
||||
if strings.HasSuffix(baseDomain, "netbird.io") ||
|
||||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
|
||||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
|
||||
strings.HasSuffix(baseDomain, "netbird.stage") ||
|
||||
strings.HasSuffix(baseDomain, anonTLD) {
|
||||
return domain
|
||||
}
|
||||
|
||||
parts := strings.Split(domain, ".")
|
||||
parts := strings.Split(baseDomain, ".")
|
||||
if len(parts) < 2 {
|
||||
return domain
|
||||
}
|
||||
|
||||
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||
|
||||
anonymized, ok := a.domainAnonymizer[baseDomain]
|
||||
anonymized, ok := a.domainAnonymizer[baseForLookup]
|
||||
if !ok {
|
||||
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
|
||||
a.domainAnonymizer[baseDomain] = anonymizedBase
|
||||
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
|
||||
a.domainAnonymizer[baseForLookup] = anonymizedBase
|
||||
anonymized = anonymizedBase
|
||||
}
|
||||
|
||||
return strings.Replace(domain, baseDomain, anonymized, 1)
|
||||
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
|
||||
if hasDot {
|
||||
result += "."
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||
@@ -152,32 +168,42 @@ func (a *Anonymizer) AnonymizeString(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
|
||||
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
|
||||
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
|
||||
|
||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||
}
|
||||
|
||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
||||
domainRegex := regexp.MustCompile(domainPattern)
|
||||
|
||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||
parts := strings.Split(match, `"`)
|
||||
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||
parts := strings.SplitN(match, "=", 2)
|
||||
if len(parts) >= 2 {
|
||||
domain := parts[1]
|
||||
if strings.HasSuffix(domain, ".domain") {
|
||||
if strings.HasSuffix(domain, anonTLD) {
|
||||
return match
|
||||
}
|
||||
randomDomain := generateRandomString(10) + ".domain"
|
||||
return strings.Replace(match, domain, randomDomain, 1)
|
||||
return "domain=" + a.AnonymizeDomain(domain)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
|
||||
// domain names with random strings.
|
||||
func (a *Anonymizer) AnonymizeRoute(route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
func isWellKnown(addr netip.Addr) bool {
|
||||
wellKnown := []string{
|
||||
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
|
||||
@@ -186,6 +212,8 @@ func isWellKnown(addr netip.Addr) bool {
|
||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||
|
||||
"128.0.0.0", "8000::", // 2nd split subnet for default routes
|
||||
}
|
||||
|
||||
if slices.Contains(wellKnown, addr.String()) {
|
||||
|
||||
@@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
|
||||
|
||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
original string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "Basic domain with trailing content",
|
||||
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
|
||||
original: "example.com",
|
||||
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
|
||||
},
|
||||
{
|
||||
name: "Domain with trailing dot",
|
||||
input: "domain=example.com. processing request with status=pending",
|
||||
original: "example.com",
|
||||
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
|
||||
},
|
||||
{
|
||||
name: "Multiple domains in log",
|
||||
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
|
||||
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
|
||||
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
|
||||
},
|
||||
{
|
||||
name: "Already anonymized domain",
|
||||
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
|
||||
original: "", // nothing should be anonymized
|
||||
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
|
||||
},
|
||||
{
|
||||
name: "Subdomain with trailing dot",
|
||||
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
|
||||
original: "example.com",
|
||||
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
|
||||
},
|
||||
{
|
||||
name: "Handler chain pattern log",
|
||||
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
|
||||
original: "example.com",
|
||||
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
|
||||
},
|
||||
}
|
||||
|
||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
||||
require.NotEqual(t, testLog, result)
|
||||
assert.NotContains(t, result, "example.com")
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeDNSLogLine(tc.input)
|
||||
if tc.original != "" {
|
||||
assert.NotContains(t, result, tc.original)
|
||||
}
|
||||
assert.Regexp(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeDomain(t *testing.T) {
|
||||
@@ -67,18 +115,36 @@ func TestAnonymizeDomain(t *testing.T) {
|
||||
`^anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Domain with Trailing Dot",
|
||||
"example.com.",
|
||||
`^anon-[a-zA-Z0-9]+\.domain.$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Subdomain",
|
||||
"sub.example.com",
|
||||
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Subdomain with Trailing Dot",
|
||||
"sub.example.com.",
|
||||
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Protected Domain",
|
||||
"netbird.io",
|
||||
`^netbird\.io$`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"Protected Domain with Trailing Dot",
|
||||
"netbird.io.",
|
||||
`^netbird\.io.$`,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -140,8 +206,16 @@ func TestAnonymizeSchemeURI(t *testing.T) {
|
||||
expect string
|
||||
}{
|
||||
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
|
||||
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
||||
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
|
||||
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
|
||||
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -3,8 +3,10 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
@@ -13,6 +15,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
@@ -58,17 +62,31 @@ var forCmd = &cobra.Command{
|
||||
RunE: runForDuration,
|
||||
}
|
||||
|
||||
var persistenceCmd = &cobra.Command{
|
||||
Use: "persistence [on|off]",
|
||||
Short: "Set network map memory persistence",
|
||||
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
|
||||
Example: " netbird debug persistence on",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: setNetworkMapPersistence,
|
||||
}
|
||||
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
@@ -84,7 +102,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
@@ -113,7 +135,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
@@ -122,17 +148,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
|
||||
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
if !initialLevelTrace {
|
||||
@@ -145,8 +174,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Enable network map persistence before bringing the service up
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -162,21 +203,32 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
// Disable network map persistence after creating the debug bundle
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if restoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
cmd.Println("Netbird down")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
@@ -186,21 +238,39 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
persistence := strings.ToLower(args[0])
|
||||
if persistence != "on" && persistence != "off" {
|
||||
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
|
||||
}
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: persistence == "on",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Printf("Network map persistence set to: %s\n", persistence)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
|
||||
@@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
@@ -42,6 +42,8 @@ var downCmd = &cobra.Command{
|
||||
log.Errorf("call service down method: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Disconnected")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -39,6 +39,11 @@ var loginCmd = &cobra.Command{
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// workaround to run without service
|
||||
if logFile == "console" {
|
||||
err = handleRebrand(cmd)
|
||||
@@ -62,7 +67,7 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -81,7 +86,7 @@ var loginCmd = &cobra.Command{
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
173
client/cmd/networks.go
Normal file
173
client/cmd/networks.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var appendFlag bool
|
||||
|
||||
var networksCMD = &cobra.Command{
|
||||
Use: "networks",
|
||||
Aliases: []string{"routes"},
|
||||
Short: "Manage networks",
|
||||
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
|
||||
}
|
||||
|
||||
var routesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List networks",
|
||||
Example: " netbird networks list",
|
||||
Long: "List all available network routes.",
|
||||
RunE: networksList,
|
||||
}
|
||||
|
||||
var routesSelectCmd = &cobra.Command{
|
||||
Use: "select network...|all",
|
||||
Short: "Select network",
|
||||
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
|
||||
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: networksSelect,
|
||||
}
|
||||
|
||||
var routesDeselectCmd = &cobra.Command{
|
||||
Use: "deselect network...|all",
|
||||
Short: "Deselect networks",
|
||||
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
|
||||
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: networksDeselect,
|
||||
}
|
||||
|
||||
func init() {
|
||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
|
||||
}
|
||||
|
||||
func networksList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.Routes) == 0 {
|
||||
cmd.Println("No networks available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
printNetworks(cmd, resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||
cmd.Println("Available Networks:")
|
||||
for _, route := range resp.Routes {
|
||||
printNetwork(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printNetwork(cmd *cobra.Command, route *proto.Network) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
if len(domains) > 0 {
|
||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||
} else {
|
||||
printNetworkRoute(cmd, route, selectedStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func getSelectedStatus(route *proto.Network) string {
|
||||
if route.GetSelected() {
|
||||
return "Selected"
|
||||
}
|
||||
return "Not Selected"
|
||||
}
|
||||
|
||||
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||
resolvedIPs := route.GetResolvedIPs()
|
||||
|
||||
if len(resolvedIPs) > 0 {
|
||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||
} else {
|
||||
cmd.Printf(" Resolved IPs: -\n")
|
||||
}
|
||||
}
|
||||
|
||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for resolvedDomain, ipList := range resolvedIPs {
|
||||
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func networksSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectNetworksRequest{
|
||||
NetworkIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
} else if appendFlag {
|
||||
req.Append = true
|
||||
}
|
||||
|
||||
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Networks selected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func networksDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectNetworksRequest{
|
||||
NetworkIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
}
|
||||
|
||||
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Networks deselected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
33
client/cmd/pprof.go
Normal file
33
client/cmd/pprof.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
addr := pprofAddr()
|
||||
go pprof(addr)
|
||||
}
|
||||
|
||||
func pprofAddr() string {
|
||||
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||
if listenAddr == "" {
|
||||
return "localhost:6969"
|
||||
}
|
||||
|
||||
return listenAddr
|
||||
}
|
||||
|
||||
func pprof(listenAddr string) {
|
||||
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
log.Fatalf("Failed to start pprof: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,8 @@ const (
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,6 +57,7 @@ var (
|
||||
managementURL string
|
||||
adminURL string
|
||||
setupKey string
|
||||
setupKeyPath string
|
||||
hostName string
|
||||
preSharedKey string
|
||||
natExternalIPs []string
|
||||
@@ -69,7 +72,9 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -91,12 +96,15 @@ func init() {
|
||||
oldDefaultConfigPathDir = "/etc/wiretrustee/"
|
||||
oldDefaultLogFileDir = "/var/log/wiretrustee/"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
|
||||
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
case "freebsd":
|
||||
defaultConfigPathDir = "/var/db/netbird/"
|
||||
}
|
||||
|
||||
defaultConfigPath = defaultConfigPathDir + "config.json"
|
||||
@@ -121,8 +129,10 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
|
||||
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
@@ -134,19 +144,20 @@ func init() {
|
||||
rootCmd.AddCommand(loginCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(routesCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
|
||||
routesCmd.AddCommand(routesListCmd)
|
||||
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
debugCmd.AddCommand(debugBundleCmd)
|
||||
debugCmd.AddCommand(logCmd)
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
@@ -165,6 +176,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -246,6 +259,21 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
func getSetupKey() (string, error) {
|
||||
if setupKeyPath != "" && setupKey == "" {
|
||||
return getSetupKeyFromFile(setupKeyPath)
|
||||
}
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
|
||||
data, err := os.ReadFile(setupKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read setup key file: %v", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
func handleRebrand(cmd *cobra.Command) error {
|
||||
var err error
|
||||
if logFile == defaultLogFile {
|
||||
|
||||
@@ -4,6 +4,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestInitCommands(t *testing.T) {
|
||||
@@ -34,3 +38,44 @@ func TestInitCommands(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetFlagsFromEnvVars(t *testing.T) {
|
||||
var cmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Long: "test",
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`comma separated list of external IPs to map to the Wireguard interface`)
|
||||
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
|
||||
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
|
||||
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
|
||||
t.Setenv("NB_INTERFACE_NAME", "test-name")
|
||||
t.Setenv("NB_ENABLE_ROSENPASS", "true")
|
||||
t.Setenv("NB_WIREGUARD_PORT", "10000")
|
||||
err := cmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error while running netbird command, got %v", err)
|
||||
}
|
||||
if len(natExternalIPs) != 2 {
|
||||
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
|
||||
}
|
||||
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
|
||||
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
|
||||
}
|
||||
if interfaceName != "test-name" {
|
||||
t.Errorf("expected test-name, got %s", interfaceName)
|
||||
}
|
||||
if !rosenpassEnabled {
|
||||
t.Errorf("expected rosenpassEnabled to be true, got false")
|
||||
}
|
||||
if wireguardPort != 10000 {
|
||||
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var appendFlag bool
|
||||
|
||||
var routesCmd = &cobra.Command{
|
||||
Use: "routes",
|
||||
Short: "Manage network routes",
|
||||
Long: `Commands to list, select, or deselect network routes.`,
|
||||
}
|
||||
|
||||
var routesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List routes",
|
||||
Example: " netbird routes list",
|
||||
Long: "List all available network routes.",
|
||||
RunE: routesList,
|
||||
}
|
||||
|
||||
var routesSelectCmd = &cobra.Command{
|
||||
Use: "select route...|all",
|
||||
Short: "Select routes",
|
||||
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
|
||||
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: routesSelect,
|
||||
}
|
||||
|
||||
var routesDeselectCmd = &cobra.Command{
|
||||
Use: "deselect route...|all",
|
||||
Short: "Deselect routes",
|
||||
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
|
||||
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: routesDeselect,
|
||||
}
|
||||
|
||||
func init() {
|
||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
|
||||
}
|
||||
|
||||
func routesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.Routes) == 0 {
|
||||
cmd.Println("No routes available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
printRoutes(cmd, resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
printRoute(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
if len(domains) > 0 {
|
||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||
} else {
|
||||
printNetworkRoute(cmd, route, selectedStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func getSelectedStatus(route *proto.Route) string {
|
||||
if route.GetSelected() {
|
||||
return "Selected"
|
||||
}
|
||||
return "Not Selected"
|
||||
}
|
||||
|
||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||
resolvedIPs := route.GetResolvedIPs()
|
||||
|
||||
if len(resolvedIPs) > 0 {
|
||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||
} else {
|
||||
cmd.Printf(" Resolved IPs: -\n")
|
||||
}
|
||||
}
|
||||
|
||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := resolvedIPs[domain]; exists {
|
||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
} else if appendFlag {
|
||||
req.Append = true
|
||||
}
|
||||
|
||||
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Routes selected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
}
|
||||
|
||||
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Routes deselected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,18 +2,23 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
serverInstanceMu sync.Mutex
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
|
||||
@@ -61,6 +61,10 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstanceMu.Lock()
|
||||
p.serverInstance = serverInstance
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
log.Errorf("failed to serve daemon requests: %v", err)
|
||||
@@ -70,6 +74,16 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
in := new(proto.DownRequest)
|
||||
_, err := p.serverInstance.Down(p.ctx, in)
|
||||
if err != nil {
|
||||
log.Errorf("failed to stop daemon: %v", err)
|
||||
}
|
||||
}
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
p.cancel()
|
||||
|
||||
if p.serv != nil {
|
||||
|
||||
@@ -31,6 +31,8 @@ var installCmd = &cobra.Command{
|
||||
configPath,
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
log.Debug(err)
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
|
||||
181
client/cmd/state.go
Normal file
181
client/cmd/state.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
allFlag bool
|
||||
)
|
||||
|
||||
var stateCmd = &cobra.Command{
|
||||
Use: "state",
|
||||
Short: "Manage daemon state",
|
||||
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
|
||||
}
|
||||
|
||||
var stateListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all stored states",
|
||||
Long: "Lists all registered states with their status and basic information.",
|
||||
Example: " netbird state list",
|
||||
RunE: stateList,
|
||||
}
|
||||
|
||||
var stateCleanCmd = &cobra.Command{
|
||||
Use: "clean [state-name]",
|
||||
Short: "Clean stored states",
|
||||
Long: `Clean specific state or all states. The daemon must not be running.
|
||||
This will perform cleanup operations and remove the state.`,
|
||||
Example: ` netbird state clean dns_state
|
||||
netbird state clean --all`,
|
||||
RunE: stateClean,
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check mutual exclusivity between --all flag and state-name argument
|
||||
if allFlag && len(args) > 0 {
|
||||
return fmt.Errorf("cannot specify both --all flag and state name")
|
||||
}
|
||||
if !allFlag && len(args) != 1 {
|
||||
return fmt.Errorf("requires a state name argument or --all flag")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var stateDeleteCmd = &cobra.Command{
|
||||
Use: "delete [state-name]",
|
||||
Short: "Delete stored states",
|
||||
Long: `Delete specific state or all states from storage. The daemon must not be running.
|
||||
This will remove the state without performing any cleanup operations.`,
|
||||
Example: ` netbird state delete dns_state
|
||||
netbird state delete --all`,
|
||||
RunE: stateDelete,
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check mutual exclusivity between --all flag and state-name argument
|
||||
if allFlag && len(args) > 0 {
|
||||
return fmt.Errorf("cannot specify both --all flag and state name")
|
||||
}
|
||||
if !allFlag && len(args) != 1 {
|
||||
return fmt.Errorf("requires a state name argument or --all flag")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(stateCmd)
|
||||
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
|
||||
|
||||
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
|
||||
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
|
||||
}
|
||||
|
||||
func stateList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Printf("\nStored states:\n\n")
|
||||
for _, state := range resp.States {
|
||||
cmd.Printf("- %s\n", state.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stateClean(cmd *cobra.Command, args []string) error {
|
||||
var stateName string
|
||||
if !allFlag {
|
||||
stateName = args[0]
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
|
||||
StateName: stateName,
|
||||
All: allFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if resp.CleanedStates == 0 {
|
||||
cmd.Println("No states were cleaned")
|
||||
return nil
|
||||
}
|
||||
|
||||
if allFlag {
|
||||
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
|
||||
} else {
|
||||
cmd.Printf("Successfully cleaned state %q\n", stateName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stateDelete(cmd *cobra.Command, args []string) error {
|
||||
var stateName string
|
||||
if !allFlag {
|
||||
stateName = args[0]
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
|
||||
StateName: stateName,
|
||||
All: allFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if resp.DeletedStates == 0 {
|
||||
cmd.Println("No states were deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
if allFlag {
|
||||
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
|
||||
} else {
|
||||
cmd.Printf("Successfully deleted state %q\n", stateName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -31,15 +31,16 @@ type peerStateDetailOutput struct {
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
Direct bool `json:"direct" yaml:"direct"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
@@ -98,6 +99,7 @@ type statusOutputOverview struct {
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
@@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
@@ -335,16 +338,18 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
connType := ""
|
||||
peersConnected := 0
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
for _, pbPeerState := range peers {
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := ""
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
@@ -360,6 +365,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
@@ -372,7 +378,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
Status: pbPeerState.GetConnStatus(),
|
||||
LastStatusUpdate: timeLocal,
|
||||
ConnType: connType,
|
||||
Direct: pbPeerState.GetDirect(),
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
@@ -381,13 +386,15 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
RelayAddress: relayServerAddress,
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetRoutes(),
|
||||
Routes: pbPeerState.GetNetworks(),
|
||||
Networks: pbPeerState.GetNetworks(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
@@ -488,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(overview.Routes) > 0 {
|
||||
sort.Strings(overview.Routes)
|
||||
routes = strings.Join(overview.Routes, ", ")
|
||||
networks := "-"
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
@@ -553,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
@@ -565,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
networks,
|
||||
networks,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
@@ -628,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
}
|
||||
}
|
||||
|
||||
routes := "-"
|
||||
if len(peerState.Routes) > 0 {
|
||||
sort.Strings(peerState.Routes)
|
||||
routes = strings.Join(peerState.Routes, ", ")
|
||||
networks := "-"
|
||||
if len(peerState.Networks) > 0 {
|
||||
sort.Strings(peerState.Networks)
|
||||
networks = strings.Join(peerState.Networks, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
@@ -641,31 +650,33 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" Direct: %t\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Relay server address: %s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n"+
|
||||
" Networks: %s\n"+
|
||||
" Latency: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
peerState.Status,
|
||||
peerState.ConnType,
|
||||
peerState.Direct,
|
||||
localICE,
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
peerState.RelayAddress,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
routes,
|
||||
networks,
|
||||
networks,
|
||||
peerState.Latency.String(),
|
||||
)
|
||||
|
||||
@@ -677,7 +688,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -697,11 +708,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
@@ -802,12 +815,23 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -842,22 +866,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
@@ -37,7 +37,6 @@ var resp = &proto.StatusResponse{
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
Direct: true,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
@@ -45,7 +44,7 @@ var resp = &proto.StatusResponse{
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Routes: []string{
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
@@ -57,7 +56,6 @@ var resp = &proto.StatusResponse{
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
Direct: false,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
@@ -95,7 +93,7 @@ var resp = &proto.StatusResponse{
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Routes: []string{
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
@@ -137,7 +135,6 @@ var overview = statusOutputOverview{
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||
ConnType: "P2P",
|
||||
Direct: true,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
@@ -152,6 +149,9 @@ var overview = statusOutputOverview{
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
{
|
||||
@@ -161,7 +161,6 @@ var overview = statusOutputOverview{
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||
ConnType: "Relayed",
|
||||
Direct: false,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
@@ -234,6 +233,9 @@ var overview = statusOutputOverview{
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -283,7 +285,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"direct": true,
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
@@ -292,6 +293,7 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
@@ -299,6 +301,9 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -308,7 +313,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"direct": false,
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
@@ -317,12 +321,14 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": null
|
||||
"routes": null,
|
||||
"networks": null
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -363,6 +369,9 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
@@ -408,13 +417,13 @@ func TestParsingToYAML(t *testing.T) {
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
direct: true
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
@@ -422,25 +431,28 @@ func TestParsingToYAML(t *testing.T) {
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
networks:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
direct: false
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
networks: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
@@ -469,6 +481,8 @@ quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
networks:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
@@ -505,14 +519,15 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
Direct: true
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
Networks: 10.1.0.0/24
|
||||
Latency: 10ms
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
@@ -521,14 +536,15 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
Direct: false
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
Networks: -
|
||||
Latency: 10ms
|
||||
|
||||
OS: %s/%s
|
||||
@@ -547,6 +563,7 @@ NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
@@ -568,6 +585,7 @@ NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
|
||||
31
client/cmd/system.go
Normal file
31
client/cmd/system.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cmd
|
||||
|
||||
// Flag constants for system configuration
|
||||
const (
|
||||
disableClientRoutesFlag = "disable-client-routes"
|
||||
disableServerRoutesFlag = "disable-server-routes"
|
||||
disableDNSFlag = "disable-dns"
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
)
|
||||
|
||||
var (
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
disableDNS bool
|
||||
disableFirewall bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add system flags to upCmd
|
||||
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,6 +10,9 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -33,18 +35,12 @@ func startTestingServices(t *testing.T) string {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
config.Datadir = testDir
|
||||
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, signalLis := startSignal(t)
|
||||
signalAddr := signalLis.Addr().String()
|
||||
config.Signal.URI = signalAddr
|
||||
|
||||
_, mgmLis := startManagement(t, config)
|
||||
_, mgmLis := startManagement(t, config, "../testdata/store.sql")
|
||||
mgmAddr := mgmLis.Addr().String()
|
||||
return mgmAddr
|
||||
}
|
||||
@@ -56,7 +52,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
srv, err := sig.NewServer(otel.Meter(""))
|
||||
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
@@ -69,14 +65,15 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -88,12 +85,17 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -48,6 +48,7 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -147,6 +148,28 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@@ -154,7 +177,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -163,7 +186,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
return connectClient.Run()
|
||||
}
|
||||
|
||||
@@ -199,8 +225,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: setupKey,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
@@ -251,6 +282,23 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
loginRequest.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
loginRequest.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
loginRequest.DisableDns = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
loginRequest.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -40,6 +41,36 @@ func TestUpDaemon(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test the setup-key-file flag.
|
||||
tempFile, err := os.CreateTemp("", "setup-key")
|
||||
if err != nil {
|
||||
t.Errorf("could not create temp file, got error %v", err)
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
|
||||
t.Errorf("could not write to temp file, got error %v", err)
|
||||
return
|
||||
}
|
||||
if err := tempFile.Close(); err != nil {
|
||||
t.Errorf("unable to close file, got error %v", err)
|
||||
}
|
||||
rootCmd.SetArgs([]string{
|
||||
"login",
|
||||
"--daemon-addr", "tcp://" + cliAddr,
|
||||
"--setup-key-file", tempFile.Name(),
|
||||
"--log-file", "",
|
||||
})
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
t.Errorf("expected no error while running up command, got %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
|
||||
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
|
||||
return
|
||||
}
|
||||
|
||||
rootCmd.SetArgs([]string{
|
||||
"up",
|
||||
"--daemon-addr", "tcp://" + cliAddr,
|
||||
|
||||
24
client/configs/configs.go
Normal file
24
client/configs/configs.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var StateDir string
|
||||
|
||||
func init() {
|
||||
StateDir = os.Getenv("NB_STATE_DIR")
|
||||
if StateDir != "" {
|
||||
return
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||
case "darwin", "linux":
|
||||
StateDir = "/var/lib/netbird"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
StateDir = "/var/db/netbird"
|
||||
}
|
||||
}
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
@@ -11,10 +10,11 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
var errFw error
|
||||
fm, err := createNativeFirewall(iface, stateManager)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
}
|
||||
|
||||
if err = fm.Init(stateManager); err != nil {
|
||||
return nil, fmt.Errorf("init firewall: %s", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Info("creating an iptables firewall manager")
|
||||
fm, errFw = nbiptables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||
}
|
||||
return nbiptables.Create(iface)
|
||||
case NFTABLES:
|
||||
log.Info("creating an nftables firewall manager")
|
||||
fm, errFw = nbnftables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||
}
|
||||
return nbnftables.Create(iface)
|
||||
default:
|
||||
errFw = fmt.Errorf("no firewall manager found")
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
var errUsp error
|
||||
if errFw == nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface)
|
||||
}
|
||||
if errUsp != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||
return nil, errUsp
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
if errUsp != nil {
|
||||
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||
}
|
||||
|
||||
if errFw != nil {
|
||||
return nil, errFw
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package firewall
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(iface.PacketFilter) error
|
||||
SetFilter(device.PacketFilter) error
|
||||
}
|
||||
|
||||
@@ -11,62 +11,78 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||
|
||||
postRoutingMark = "0x000007e4"
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
)
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routeingFwChainName string
|
||||
type aclEntries map[string][][]string
|
||||
|
||||
entries map[string][][]string
|
||||
ipsetStore *ipsetStore
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routeingFwChainName: routeingFwChainName,
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
ipsetStore: newIpsetStore(),
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
m.seedInitialEntries()
|
||||
|
||||
err = m.cleanChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = m.createDefaultChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddFiltering(
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
m.stateManager = stateManager
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
if err := m.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
@@ -79,15 +95,10 @@ func (m *aclManager) AddFiltering(
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
|
||||
var chain string
|
||||
if direction == firewall.RuleDirectionOUT {
|
||||
chain = chainNameOutputRules
|
||||
} else {
|
||||
chain = chainNameInputRules
|
||||
}
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
@@ -127,7 +138,7 @@ func (m *aclManager) AddFiltering(
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -139,28 +150,18 @@ func (m *aclManager) AddFiltering(
|
||||
chain: chain,
|
||||
}
|
||||
|
||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
m.updateState()
|
||||
|
||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
||||
if err != nil {
|
||||
return []firewall.Rule{rule}, err
|
||||
}
|
||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.chain == "PREROUTING" {
|
||||
goto DELETERULE
|
||||
}
|
||||
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
@@ -185,86 +186,28 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
DELETERULE:
|
||||
var table string
|
||||
if r.chain == "PREROUTING" {
|
||||
table = "mangle"
|
||||
} else {
|
||||
table = "filter"
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
||||
if err != nil {
|
||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||
}
|
||||
return err
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
return m.cleanChains()
|
||||
}
|
||||
|
||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
||||
var src []string
|
||||
if ipsetName != "" {
|
||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
||||
} else {
|
||||
src = []string{"-s", ip.String()}
|
||||
}
|
||||
specs := []string{
|
||||
"-d", m.wgIface.Address().IP.String(),
|
||||
"-p", protocol,
|
||||
"--dport", port,
|
||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
specs = append(src, specs...)
|
||||
m.updateState()
|
||||
|
||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: "PREROUTING",
|
||||
}
|
||||
return rule, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
rules := m.entries["OUTPUT"]
|
||||
for _, rule := range rules {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
@@ -293,8 +236,7 @@ func (m *aclManager) cleanChains() error {
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
@@ -303,11 +245,6 @@ func (m *aclManager) cleanChains() error {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
@@ -330,97 +267,105 @@ func (m *aclManager) createDefaultChains() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// chain netbird-acl-output-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
for _, rule := range rules {
|
||||
if chainName == "FORWARD" {
|
||||
// position 2 because we add it after router's, jump rule
|
||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chainName, entries := range m.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||
// We want to make sure our traffic is not dropped by existing rules.
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("INPUT",
|
||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
||||
established := getConntrackEstablished()
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||
|
||||
m.appendToEntries("OUTPUT",
|
||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
||||
|
||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD",
|
||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
}
|
||||
|
||||
m.appendToEntries("PREROUTING",
|
||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
func (m *aclManager) updateState() {
|
||||
if m.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(
|
||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||
) (specs []string) {
|
||||
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case firewall.RuleDirectionIN:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
case firewall.RuleDirectionOUT:
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||
} else {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
@@ -456,18 +401,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||
return ipsetName
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||
if proto == "all" {
|
||||
return false
|
||||
}
|
||||
|
||||
if direction != firewall.RuleDirectionIN {
|
||||
return false
|
||||
}
|
||||
|
||||
if dPort == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -4,13 +4,17 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -21,7 +25,7 @@ type Manager struct {
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *routerManager
|
||||
router *router
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -32,10 +36,10 @@ type iFaceMapper interface {
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
return nil, fmt.Errorf("init iptables: %w", err)
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
@@ -43,80 +47,146 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouterManager(context, iptablesClient)
|
||||
m.router, err = newRouter(iptablesClient, wgIface)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize route related chains: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.init(stateManager); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclMgr.init(stateManager); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddFiltering(
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.DeleteRule(rule)
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
return m.router.AddNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
errAcl := m.aclMgr.Reset()
|
||||
if errAcl != nil {
|
||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
errMgr := m.router.Reset()
|
||||
if errMgr != nil {
|
||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
||||
return errMgr
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
return errAcl
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -125,31 +195,24 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
_, err := m.AddPeerFiltering(
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionIN,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.RuleDirectionOUT,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
@@ -11,9 +10,24 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
@@ -40,55 +54,27 @@ func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -97,18 +83,9 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeleteRule(r)
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
@@ -119,10 +96,10 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
@@ -135,9 +112,6 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIptablesManagerIPSet(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
@@ -154,45 +128,26 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 []fw.Rule
|
||||
t.Run("add first rule with set", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(
|
||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule1 {
|
||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||
}
|
||||
})
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||
"default", "accept HTTPS traffic from ports range",
|
||||
)
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -200,18 +155,9 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
for _, r := range rule1 {
|
||||
err := manager.DeleteRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeleteRule(r)
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
@@ -219,7 +165,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
@@ -251,12 +197,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -268,11 +215,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -3,370 +3,597 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
chainFORWARD = "FORWARD"
|
||||
tableMangle = "mangle"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpPre = "jump-pre"
|
||||
jumpNat = "jump-nat"
|
||||
matchSet = "--match-set"
|
||||
)
|
||||
|
||||
type routerManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
SetName string
|
||||
}
|
||||
|
||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
m := &routerManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
type routeRules map[string][]string
|
||||
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
type router struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
rules routeRules
|
||||
ipsetCounter *ipsetCounter
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
r := &router{
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
err := m.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
r.ipsetCounter = refcounter.New(
|
||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||
return struct{}{}, r.createIpSet(name, sources)
|
||||
},
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
err = m.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return m, err
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
r.stateManager = stateManager
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
var setName string
|
||||
if len(sources) > 1 {
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: sources,
|
||||
Destination: destination,
|
||||
Proto: proto,
|
||||
SPort: sPort,
|
||||
DPort: dPort,
|
||||
Action: action,
|
||||
SetName: setName,
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return nil, fmt.Errorf("add route rule: %v", err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
r.updateState()
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.GetRuleID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule []string) string {
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
return rule[i+3]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts an iptables rule
|
||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
||||
var err error
|
||||
// GetLegacyManagement returns the current legacy management mode
|
||||
func (r *router) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
|
||||
return nil
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (i *routerManager) RouteingFwChainName() string {
|
||||
return chainRTFWD
|
||||
}
|
||||
|
||||
func (i *routerManager) Reset() error {
|
||||
err := i.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
return err
|
||||
func (r *router) Reset() error {
|
||||
var merr *multierror.Error
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
i.rules = make(map[string][]string)
|
||||
return nil
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
||||
err := i.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if err := r.cleanJumpRules(); err != nil {
|
||||
return fmt.Errorf("clean jump rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTPRE, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
||||
return err
|
||||
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
} else if ok {
|
||||
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
} {
|
||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add static nat rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
"!", "-o", "lo",
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||
return fmt.Errorf("add outbound masquerade rule: %v", err)
|
||||
}
|
||||
r.rules["static-nat-outbound"] = rule1
|
||||
|
||||
// Second rule for return traffic masquerade
|
||||
rule2 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
"-o", r.wgIface.Name(),
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||
return fmt.Errorf("add return masquerade rule: %v", err)
|
||||
}
|
||||
r.rules["static-nat-return"] = rule2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
switch chain {
|
||||
case chainRTNAT:
|
||||
return tableNat
|
||||
case chainRTPRE:
|
||||
return tableMangle
|
||||
default:
|
||||
return tableFilter
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
} else if ok {
|
||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
||||
return err
|
||||
return fmt.Errorf("failed to insert established rule: %v", err)
|
||||
}
|
||||
|
||||
ruleKey := "established-" + chain
|
||||
r.rules[ruleKey] = establishedRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
// Jump to NAT chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpNat] = natRule
|
||||
|
||||
// Jump to prerouting chain
|
||||
preRule := []string{"-j", chainRTPRE}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpPre] = preRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
table := tableNat
|
||||
chain := chainPOSTROUTING
|
||||
if ruleKey == jumpPre {
|
||||
table = tableMangle
|
||||
chain = chainPREROUTING
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createContainers() error {
|
||||
if i.rules[Ipv4Forwarding] != nil {
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
rule := []string{"-i", r.wgIface.Name()}
|
||||
if pair.Inverse {
|
||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||
}
|
||||
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", pair.Source.String(),
|
||||
"-d", pair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) updateState() {
|
||||
if r.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
if params.SetName != "" {
|
||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||
} else if len(params.Sources) > 0 {
|
||||
source := params.Sources[0]
|
||||
rule = append(rule, "-s", source.String())
|
||||
}
|
||||
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
}
|
||||
|
||||
rule = append(rule, "-j", actionToStr(params.Action))
|
||||
|
||||
return rule
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errMSGFormat := "failed creating chain %s,error: %v"
|
||||
err := i.createChain(tableFilter, chainRTFWD)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||
}
|
||||
|
||||
err = i.createChain(tableNat, chainRTNAT)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(p)
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
}
|
||||
|
||||
err = i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addJumpRules create jump rules to send packets to NetBird chains
|
||||
func (i *routerManager) addJumpRules() error {
|
||||
rule := []string{"-j", chainRTFWD}
|
||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[Ipv4Forwarding] = rule
|
||||
|
||||
rule = []string{"-j", chainRTNAT}
|
||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||
func (i *routerManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
||||
rule, found := i.rules[Ipv4Forwarding]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4Nat]
|
||||
if found {
|
||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
for _, ruleString := range rules {
|
||||
if !strings.Contains(ruleString, "NETBIRD") {
|
||||
continue
|
||||
}
|
||||
rule := strings.Fields(ruleString)
|
||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *routerManager) createChain(table, newChain string) error {
|
||||
chains, err := i.iptablesClient.ListChains(table)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
||||
}
|
||||
|
||||
shouldCreateChain := true
|
||||
for _, chain := range chains {
|
||||
if chain == newChain {
|
||||
shouldCreateChain = false
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCreateChain {
|
||||
err = i.iptablesClient.NewChain(table, newChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
}
|
||||
|
||||
func getIptablesRuleType(table string) string {
|
||||
ruleType := "forwarding"
|
||||
if table == tableNat {
|
||||
ruleType = "nat"
|
||||
}
|
||||
return ruleType
|
||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,18 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
@@ -28,45 +30,45 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
// Now 5 rules:
|
||||
// 1. established rule in forward chain
|
||||
// 2. jump rule to NAT chain
|
||||
// 3. jump rule to PRE chain
|
||||
// 4. static outbound masquerade rule
|
||||
// 5. static return masquerade rule
|
||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
require.True(t, exists, "postrouting jump rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||
require.True(t, exists, "prerouting jump rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.100.0/24",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "adding NAT rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
@@ -76,78 +78,71 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := manager.Reset()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset iptables manager: %s", err)
|
||||
}
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "marking rule should be inserted")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
foundRule, found := manager.rules[forwardRuleKey]
|
||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.True(t, exists, "income forwarding rule should exist")
|
||||
|
||||
foundRule, found = manager.rules[inForwardRuleKey]
|
||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
||||
require.True(t, foundNat, "nat rule should exist in the map")
|
||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[natRuleKey]
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", testCase.InputPair.Source.String(),
|
||||
"-d", testCase.InputPair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "income nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||
require.True(t, exists, "marking rule should be created")
|
||||
foundRule, found := manager.rules[natRuleKey]
|
||||
require.True(t, found, "marking rule should exist in the map")
|
||||
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[inNatRuleKey]
|
||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||
require.False(t, exists, "marking rule should not be created")
|
||||
_, found := manager.rules[natRuleKey]
|
||||
require.False(t, found, "marking rule should not exist in the map")
|
||||
}
|
||||
|
||||
// Check inverse rule
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", inversePair.Source.String(),
|
||||
"-d", inversePair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "inverse marking rule should be created")
|
||||
foundRule, found := manager.rules[inverseRuleKey]
|
||||
require.True(t, found, "inverse marking rule should exist in the map")
|
||||
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "inverse marking rule should not be created")
|
||||
_, found := manager.rules[inverseRuleKey]
|
||||
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
@@ -156,72 +151,226 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
_ = manager.Reset()
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule without error")
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", testCase.InputPair.Source.String(),
|
||||
"-d", testCase.InputPair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
require.False(t, exists, "marking rule should not exist")
|
||||
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
_, found := manager.rules[natRuleKey]
|
||||
require.False(t, found, "marking rule should not exist in the manager map")
|
||||
|
||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
// Check inverse rule removal
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", inversePair.Source.String(),
|
||||
"-d", inversePair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "forwarding rule should not exist")
|
||||
|
||||
_, found := manager.rules[forwardRuleKey]
|
||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||
require.False(t, exists, "income forwarding rule should not exist")
|
||||
|
||||
_, found = manager.rules[inForwardRuleKey]
|
||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[natRuleKey]
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "income nat rule should not exist")
|
||||
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
require.False(t, exists, "inverse marking rule should not exist")
|
||||
|
||||
_, found = manager.rules[inverseRuleKey]
|
||||
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
if !isIptablesSupported() {
|
||||
t.Skip("iptables not supported on this system")
|
||||
}
|
||||
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
require.NoError(t, err, "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with multiple sources",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: tt.sources,
|
||||
Destination: tt.destination,
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
SetName: "",
|
||||
}
|
||||
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
assert.True(t, exists, "IPSet not created")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) ipList {
|
||||
func newIpList(ip string) *ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return ipList{
|
||||
return &ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
@@ -17,27 +19,52 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{
|
||||
IPs: s.ips,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
|
||||
if temp.IPs == nil {
|
||||
temp.IPs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]ipList // ipsetName -> ruleset
|
||||
ipsets map[string]*ipList
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]ipList),
|
||||
ipsets: make(map[string]*ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
s.ipsets[ipsetName] = ipList{}
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
@@ -48,3 +75,29 @@ func (s *ipsetStore) ipsetNames() []string {
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{
|
||||
IPSets: s.ipsets,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
|
||||
if temp.IPSets == nil {
|
||||
temp.IPSets = make(map[string]*ipList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
70
client/firewall/iptables/state_linux.go
Normal file
70
client/firewall/iptables/state_linux.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
sync.Mutex
|
||||
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "iptables_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
ipt, err := Create(s.InterfaceState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iptables manager: %w", err)
|
||||
}
|
||||
|
||||
if s.RouteRules != nil {
|
||||
ipt.router.rules = s.RouteRules
|
||||
}
|
||||
if s.RouteIPsetCounter != nil {
|
||||
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
}
|
||||
|
||||
if s.ACLEntries != nil {
|
||||
ipt.aclMgr.entries = s.ACLEntries
|
||||
}
|
||||
if s.ACLIPsetStore != nil {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
if err := ipt.Reset(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,15 +1,24 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
NatFormat = "netbird-nat-%s"
|
||||
ForwardingFormat = "netbird-fwd-%s"
|
||||
InNatFormat = "netbird-nat-in-%s"
|
||||
InForwardingFormat = "netbird-fwd-in-%s"
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
@@ -46,43 +55,135 @@ const (
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
Init(stateManager *statemanager.Manager) error
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddFiltering(
|
||||
AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
DeleteRule(rule Rule) error
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
// InsertRoutingRules inserts a routing firewall rule
|
||||
InsertRoutingRules(pair RouterPair) error
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
RemoveRoutingRules(pair RouterPair) error
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
// RemoveNatRule removes a routing NAT rule
|
||||
RemoveNatRule(pair RouterPair) error
|
||||
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
}
|
||||
|
||||
func GenKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
type LegacyManager interface {
|
||||
RemoveAllLegacyRouteRules() error
|
||||
GetLegacyManagement() bool
|
||||
SetLegacyManagement(bool)
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
oldLegacy := router.GetLegacyManagement()
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
router.SetLegacyManagement(isLegacy)
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to clean up the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
return prefixes
|
||||
}
|
||||
|
||||
merged := []netip.Prefix{prefixes[0]}
|
||||
for _, prefix := range prefixes[1:] {
|
||||
last := merged[len(merged)-1]
|
||||
if last.Contains(prefix.Addr()) {
|
||||
// If the current prefix is contained within the last merged prefix, skip it
|
||||
continue
|
||||
}
|
||||
if prefix.Contains(last.Addr()) {
|
||||
// If the current prefix contains the last merged prefix, replace it
|
||||
merged[len(merged)-1] = prefix
|
||||
} else {
|
||||
// Otherwise, add the current prefix to the merged list
|
||||
merged = append(merged, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
return addrCmp < 0
|
||||
}
|
||||
|
||||
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
|
||||
return prefixes[i].Bits() > prefixes[j].Bits()
|
||||
})
|
||||
}
|
||||
|
||||
192
client/firewall/manager/firewall_test.go
Normal file
192
client/firewall/manager/firewall_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func TestGenerateSetName(t *testing.T) {
|
||||
t.Run("Different orders result in same hash", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Result format is correct", func(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Errorf("Result format is incorrect: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeIPRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []netip.Prefix{},
|
||||
expected: []netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Two non-overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partially overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IPv6 ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::/48"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := manager.MergeIPRanges(tt.input)
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,26 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouterPair struct {
|
||||
ID string
|
||||
Source string
|
||||
Destination string
|
||||
ID route.ID
|
||||
Source netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
}
|
||||
|
||||
func GetInPair(pair RouterPair) RouterPair {
|
||||
func GetInversePair(pair RouterPair) RouterPair {
|
||||
return RouterPair{
|
||||
ID: pair.ID,
|
||||
// invert Source/Destination
|
||||
Source: pair.Destination,
|
||||
Destination: pair.Source,
|
||||
Masquerade: pair.Masquerade,
|
||||
Inverse: true,
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,20 +5,34 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
||||
tableName = "netbird"
|
||||
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||
tableNameNetbird = "netbird"
|
||||
|
||||
tableNameFilter = "filter"
|
||||
chainNameInput = "INPUT"
|
||||
)
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
@@ -30,40 +44,79 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable, err := m.createWorkTable()
|
||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.router, err = newRouter(context, workTable)
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
// Init nftables firewall manager
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
workTable, err := m.createWorkTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
// TODO: cleanup router
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Reset() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
|
||||
// persist early
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
@@ -76,33 +129,59 @@ func (m *Manager) AddFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.DeleteRule(rule)
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddRoutingRules(pair)
|
||||
return m.router.AddNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveRoutingRules(pair)
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
@@ -126,7 +205,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@@ -157,47 +236,86 @@ func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
if err := m.resetNetbirdInputRules(); err != nil {
|
||||
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
return fmt.Errorf("delete state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetNetbirdInputRules() error {
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
m.deleteNetbirdInputRules(chains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.deleteMatchingRules(rules)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.router.ResetForwardRules()
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
return fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
if t.Name == tableNameNetbird {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.rConn.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
@@ -218,12 +336,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
if t.Name == tableNameNetbird {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
@@ -251,7 +369,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
@@ -265,3 +383,38 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: getEstablishedExprs(1),
|
||||
}
|
||||
|
||||
conn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func getEstablishedExprs(register uint32) []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: register,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: register,
|
||||
DestRegister: register,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: register,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,39 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
@@ -40,28 +57,15 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
mock := &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
@@ -70,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -88,17 +83,35 @@ func TestNftablesManager(t *testing.T) {
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 1, "expected 1 rules")
|
||||
require.Len(t, rules, 2, "expected 2 rules")
|
||||
|
||||
expectedExprs1 := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
expectedExprs2 := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@@ -134,10 +147,10 @@ func TestNftablesManager(t *testing.T) {
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeleteRule(r)
|
||||
err = manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
@@ -146,9 +159,10 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
|
||||
err = manager.Reset()
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
@@ -171,12 +185,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(context.Background(), mock)
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -186,11 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -203,3 +214,96 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runIptablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("iptables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err, "iptables-save failed to run")
|
||||
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
// Check for any incompatibility warnings
|
||||
require.NotContains(t,
|
||||
stderr,
|
||||
"incompatible",
|
||||
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||
stderr,
|
||||
)
|
||||
|
||||
// Verify standard tables are present
|
||||
expectedTables := []string{
|
||||
"*filter",
|
||||
"*nat",
|
||||
"*mangle",
|
||||
}
|
||||
|
||||
for _, table := range expectedTables {
|
||||
require.Contains(t,
|
||||
stdout,
|
||||
table,
|
||||
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||
table,
|
||||
stdout,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
@@ -1,431 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRouteingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-nat"
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
var (
|
||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||
|
||||
exprCounterAccept = []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
isDefaultFwdRulesEnabled bool
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create containers for route: %s", err)
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (r *router) RouteingFwChainName() string {
|
||||
return chainNameRouteingFw
|
||||
}
|
||||
|
||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||
func (r *router) ResetForwardRules() {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset forward rules: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
|
||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRouteingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityNATSource - 1,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||
log.Debugf("add default accept forward rule")
|
||||
r.acceptForwardRule(pair.Source)
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
var expression []expr.Any
|
||||
if isNat {
|
||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
||||
} else {
|
||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
||||
}
|
||||
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
_, exists := r.rules[ruleKey]
|
||||
if exists {
|
||||
err := r.removeRoutingRule(format, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||
}
|
||||
|
||||
r.conn.AddRule(rule)
|
||||
|
||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||
|
||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||
Kind: expr.VerdictAccept,
|
||||
})...)
|
||||
|
||||
rule = &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: exprs,
|
||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||
}
|
||||
r.conn.AddRule(rule)
|
||||
r.isDefaultFwdRulesEnabled = true
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.rules) == 0 {
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||
ruleKey := manager.GenKey(format, pair.ID)
|
||||
|
||||
rule, found := r.rules[ruleKey]
|
||||
if found {
|
||||
ruleType := "forwarding"
|
||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||
ruleType = "nat"
|
||||
}
|
||||
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rules []*nftables.Rule
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
if chain.Name != "FORWARD" {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
r.isDefaultFwdRulesEnabled = false
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||
ip, network, _ := net.ParseCIDR(cidr)
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
var offSet uint32
|
||||
if source {
|
||||
offSet = 12 // src offset
|
||||
} else {
|
||||
offSet = 16 // dst offset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
// fetch src add
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offSet,
|
||||
Len: 4,
|
||||
},
|
||||
// net mask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: network.Mask,
|
||||
Xor: zeroXor,
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
989
client/firewall/nftables/router_linux.go
Normal file
989
client/firewall/nftables/router_linux.go
Normal file
@@ -0,0 +1,989 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameForward = "FORWARD"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
)
|
||||
|
||||
type router struct {
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
r := &router{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
r.deleteIpSet,
|
||||
)
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, fmt.Errorf("load filter table: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *router) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
return r.removeAcceptForwardRules()
|
||||
}
|
||||
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: &prio,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Chain is created by acl manager
|
||||
// TODO: move creation to a common place
|
||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: r.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
chain := r.chains[chainNameRoutingFw]
|
||||
var exprs []expr.Any
|
||||
|
||||
switch {
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||
case len(sources) == 1:
|
||||
// If there's only one source, we can use it directly
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
||||
default:
|
||||
// If there are multiple sources, create or get an ipset
|
||||
var err error
|
||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle destination
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
})
|
||||
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
|
||||
var verdict expr.VerdictKind
|
||||
if action == firewall.ActionAccept {
|
||||
verdict = expr.VerdictAccept
|
||||
} else {
|
||||
verdict = expr.VerdictDrop
|
||||
}
|
||||
exprs = append(exprs, &expr.Verdict{Kind: verdict})
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
}
|
||||
|
||||
rule = r.conn.AddRule(rule)
|
||||
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
)
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleKey := rule.GetRuleID()
|
||||
nftRule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if nftRule.Handle == 0 {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
return fmt.Errorf("delete: %w", err)
|
||||
}
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range sources {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
|
||||
if err := r.conn.AddSet(set, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||
b := addr.As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||
func uint32ToBytes(ip uint32) [4]byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], ip)
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
r.conn.DelSet(set)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
return lookup.SetName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
log.Debugf("removed route rule %s", ruleKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
|
||||
// interface matching
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: op,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNamePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPostroutingRules adds the masquerade rules
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
})
|
||||
|
||||
// Second masquerade rule for traffic going out through WireGuard interface
|
||||
exprs2 := []expr.Any{
|
||||
// Match on the second fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
},
|
||||
|
||||
// Match WireGuard interface
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: expression,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the route manager's legacy management mode
|
||||
func (r *router) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// filter table exists but iptables is not
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptForwardRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptForwardRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) getAcceptForwardRules() [][]string {
|
||||
intf := r.wgIface.Name()
|
||||
return [][]string{
|
||||
{"-i", intf, "-j", "ACCEPT"},
|
||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
}
|
||||
r.conn.InsertRule(iifRule)
|
||||
|
||||
oifExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptForwardRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptForwardRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||
func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
r.rules[string(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||
var offset uint32
|
||||
if source {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
// 0.0.0.0/0 doesn't need extra expressions
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, 32)
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
},
|
||||
// netmask
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 4,
|
||||
Mask: mask,
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: prefix.Masked().Addr().AsSlice(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
|
||||
offset := uint32(2) // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: offset,
|
||||
Len: 2,
|
||||
})
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
// Handle port range
|
||||
exprs = append(exprs,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpLte,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Handle single port or multiple ports
|
||||
for i, p := range port.Values {
|
||||
if i > 0 {
|
||||
// Add a bitwise OR operation between port checks
|
||||
exprs = append(exprs, &expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
})
|
||||
}
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
@@ -3,12 +3,16 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -24,195 +28,627 @@ const (
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.ResetForwardRules()
|
||||
rtr := manager.router
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
|
||||
})
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
// Build expected expressions for connection tracking
|
||||
conntrackExprs := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
// Build interface matching expression
|
||||
ifaceExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
testingExpression := append(conntrackExprs, ifaceExprs...)
|
||||
testingExpression = append(testingExpression, sourceExp...)
|
||||
testingExpression = append(testingExpression, destExp...)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||
found = 1
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
// Compare expressions up to the mark setting expressions
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
|
||||
found = 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
rtr := manager.router
|
||||
|
||||
// First add the NAT rule using the router's method
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found, "NAT rule should exist before removal")
|
||||
|
||||
// Now remove the rule
|
||||
err = rtr.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error when removing rule")
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.False(t, found, "NAT rule should not exist after removal")
|
||||
|
||||
// Verify the static postrouting rules still exist
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
|
||||
require.NoError(t, err, "should list postrouting rules")
|
||||
foundCounter := false
|
||||
for _, rule := range rules {
|
||||
for _, e := range rule.Exprs {
|
||||
if _, ok := e.(*expr.Counter); ok {
|
||||
foundCounter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundCounter {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundCounter, "static postrouting rule should remain")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := newRouter(context.TODO(), table)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
defer func(r *router) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
}(r)
|
||||
|
||||
defer manager.ResetForwardRules()
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with multiple sources",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(forwardRuleKey),
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
t.Log("Internal rule expressions:")
|
||||
for i, expr := range rule.Exprs {
|
||||
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||
}
|
||||
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||
// Verify internal rule content
|
||||
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
|
||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRouteingFw],
|
||||
Exprs: forwardExp,
|
||||
UserData: []byte(inForwardRuleKey),
|
||||
})
|
||||
// Check if the rule exists in nftables and verify its content
|
||||
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||
require.NoError(t, err, "Failed to get rules from nftables")
|
||||
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(inNatRuleKey),
|
||||
})
|
||||
require.NotNil(t, nftRule, "Rule not found in nftables")
|
||||
t.Log("Actual nftables rule expressions:")
|
||||
for i, expr := range nftRule.Exprs {
|
||||
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||
}
|
||||
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
// Verify actual nftables rule content
|
||||
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
manager.ResetForwardRules()
|
||||
func TestNftablesCreateIpSet(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Single IP",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
},
|
||||
{
|
||||
name: "Multiple IPs",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("10.0.0.1/32"),
|
||||
netip.MustParsePrefix("172.16.0.1/32"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Single Subnet",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
|
||||
},
|
||||
{
|
||||
name: "Multiple Subnets with Various Prefix Lengths",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("203.0.113.0/26"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Mix of Single IPs and Subnets in Different Positions",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("10.0.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.1/32"),
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping IPs/Subnets",
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("10.0.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.1/32"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add this helper function inside TestNftablesCreateIpSet
|
||||
printNftSets := func() {
|
||||
cmd := exec.Command("nft", "list", "sets")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Failed to run 'nft list sets': %v", err)
|
||||
} else {
|
||||
t.Logf("Current nft sets:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
printNftSets()
|
||||
require.NoError(t, err, "Failed to create IP set")
|
||||
}
|
||||
require.NotNil(t, set, "Created set is nil")
|
||||
|
||||
// Verify set properties
|
||||
assert.Equal(t, setName, set.Name, "Set name mismatch")
|
||||
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
|
||||
assert.True(t, set.Interval, "Set interval property should be true")
|
||||
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
|
||||
|
||||
// Fetch the created set from nftables
|
||||
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||
require.NoError(t, err, "Failed to fetch created set")
|
||||
require.NotNil(t, fetchedSet, "Fetched set is nil")
|
||||
|
||||
// Verify set elements
|
||||
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||
require.NoError(t, err, "Failed to get set elements")
|
||||
|
||||
// Count the number of unique prefixes (excluding interval end markers)
|
||||
uniquePrefixes := make(map[string]bool)
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd {
|
||||
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||
uniquePrefixes[ip.String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check against expected merged prefixes
|
||||
expectedCount := len(tt.expected)
|
||||
if expectedCount == 0 {
|
||||
expectedCount = len(tt.sources)
|
||||
}
|
||||
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
|
||||
|
||||
// Verify each expected prefix is in the set
|
||||
for _, expected := range tt.expected {
|
||||
found := false
|
||||
for _, elem := range elements {
|
||||
if !elem.IntervalEnd {
|
||||
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||
if expected.Contains(ip) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected prefix %s not found in set", expected)
|
||||
}
|
||||
|
||||
r.conn.DelSet(set)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
t.Logf("Failed to delete set: %v", err)
|
||||
printNftSets()
|
||||
}
|
||||
require.NoError(t, err, "Failed to delete set")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotNil(t, rule, "Rule should not be nil")
|
||||
|
||||
// Verify sources and destination
|
||||
if expectSet {
|
||||
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
|
||||
} else if len(sources) == 1 && sources[0].Bits() != 0 {
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
|
||||
}
|
||||
}
|
||||
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
|
||||
}
|
||||
|
||||
// Verify protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
|
||||
}
|
||||
|
||||
// Verify ports
|
||||
if sPort != nil {
|
||||
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
|
||||
}
|
||||
if dPort != nil {
|
||||
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
|
||||
}
|
||||
|
||||
// Verify action
|
||||
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
|
||||
}
|
||||
|
||||
func containsSetLookup(exprs []expr.Any) bool {
|
||||
for _, e := range exprs {
|
||||
if _, ok := e.(*expr.Lookup); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
|
||||
var offset uint32
|
||||
if isSource {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
}
|
||||
|
||||
var payloadFound, bitwiseFound, cmpFound bool
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Payload:
|
||||
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Bitwise:
|
||||
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
|
||||
bitwiseFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
|
||||
cmpFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
|
||||
}
|
||||
|
||||
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
var offset uint32 = 2 // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
var payloadFound, portMatchFound bool
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Payload:
|
||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if port.IsRange {
|
||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
||||
portMatchFound = true
|
||||
}
|
||||
} else {
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||
for _, p := range port.Values {
|
||||
if uint16(p) == portValue {
|
||||
portMatchFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if payloadFound && portMatchFound {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
for _, e := range exprs {
|
||||
switch ex := e.(type) {
|
||||
case *expr.Meta:
|
||||
if ex.Key == expr.MetaKeyL4PROTO {
|
||||
metaFound = true
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
|
||||
cmpFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return metaFound && cmpFound
|
||||
}
|
||||
|
||||
func containsAction(exprs []expr.Any, action firewall.Action) bool {
|
||||
for _, e := range exprs {
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
return verdict.Kind == expr.VerdictAccept
|
||||
case firewall.ActionDrop:
|
||||
return verdict.Kind == expr.VerdictDrop
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
@@ -250,12 +686,12 @@ func createWorkTable() (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
err = sConn.Flush()
|
||||
|
||||
return table, err
|
||||
@@ -273,7 +709,7 @@ func deleteWorkTable() {
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == tableName {
|
||||
if t.Name == tableNameNetbird {
|
||||
sConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
47
client/firewall/nftables/state_linux.go
Normal file
47
client/firewall/nftables/state_linux.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "nftables_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
nft, err := Create(s.InterfaceState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
if err := nft.Reset(nil); err != nil {
|
||||
return fmt.Errorf("reset nftables manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
//go:build !android
|
||||
|
||||
package test
|
||||
|
||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
var (
|
||||
InsertRuleTestCases = []struct {
|
||||
@@ -13,8 +15,8 @@ var (
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
@@ -22,8 +24,8 @@ var (
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
@@ -38,8 +40,8 @@ var (
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: "100.100.100.1/32",
|
||||
Destination: "100.100.200.0/24",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -2,16 +2,36 @@
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Reset()
|
||||
return m.nativeFirewall.Reset(stateManager)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type action string
|
||||
@@ -17,13 +20,28 @@ const (
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
137
client/firewall/uspfilter/conntrack/common.go
Normal file
137
client/firewall/uspfilter/conntrack/common.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// common.go
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
established atomic.Bool
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
|
||||
// UpdateLastSeen safely updates the last seen timestamp
|
||||
func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (b *BaseConnTrack) IsEstablished() bool {
|
||||
return b.established.Load()
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (b *BaseConnTrack) SetEstablished(state bool) {
|
||||
b.established.Store(state)
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
}
|
||||
|
||||
// timeoutExceeded checks if the connection has exceeded the given timeout
|
||||
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
||||
lastSeen := time.Unix(0, b.lastSeen.Load())
|
||||
return time.Since(lastSeen) > timeout
|
||||
}
|
||||
|
||||
// IPAddr is a fixed-size IP address to avoid allocations
|
||||
type IPAddr [16]byte
|
||||
|
||||
// MakeIPAddr creates an IPAddr from net.IP
|
||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
||||
// Optimization: check for v4 first as it's more common
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
copy(addr[12:], ip4)
|
||||
} else {
|
||||
copy(addr[:], ip.To16())
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// ConnKey uniquely identifies a connection
|
||||
type ConnKey struct {
|
||||
SrcIP IPAddr
|
||||
DstIP IPAddr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// makeConnKey creates a connection key
|
||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||
return ConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateIPs checks if IPs match without allocation
|
||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
||||
if ip4 := pktIP.To4(); ip4 != nil {
|
||||
// Compare IPv4 addresses (last 4 bytes)
|
||||
for i := 0; i < 4; i++ {
|
||||
if connIP[12+i] != ip4[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Compare full IPv6 addresses
|
||||
ip6 := pktIP.To16()
|
||||
for i := 0; i < 16; i++ {
|
||||
if connIP[i] != ip6[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
||||
type PreallocatedIPs struct {
|
||||
sync.Pool
|
||||
}
|
||||
|
||||
// NewPreallocatedIPs creates a new IP pool
|
||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
||||
return &PreallocatedIPs{
|
||||
Pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
ip := make(net.IP, 16)
|
||||
return &ip
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an IP from the pool
|
||||
func (p *PreallocatedIPs) Get() net.IP {
|
||||
return *p.Pool.Get().(*net.IP)
|
||||
}
|
||||
|
||||
// Put returns an IP to the pool
|
||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
||||
p.Pool.Put(&ip)
|
||||
}
|
||||
|
||||
// copyIP copies an IP address efficiently
|
||||
func copyIP(dst, src net.IP) {
|
||||
if len(src) == 16 {
|
||||
copy(dst, src)
|
||||
} else {
|
||||
// Handle IPv4
|
||||
copy(dst[12:], src.To4())
|
||||
}
|
||||
}
|
||||
115
client/firewall/uspfilter/conntrack/common_test.go
Normal file
115
client/firewall/uspfilter/conntrack/common_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = MakeIPAddr(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ValidateIPs", func(b *testing.B) {
|
||||
ip1 := net.ParseIP("192.168.1.1")
|
||||
ip2 := net.ParseIP("192.168.1.1")
|
||||
addr := MakeIPAddr(ip1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ValidateIPs(addr, ip2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IPPool", func(b *testing.B) {
|
||||
pool := NewPreallocatedIPs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ip := pool.Get()
|
||||
pool.Put(ip)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
func BenchmarkAtomicOperations(b *testing.B) {
|
||||
conn := &BaseConnTrack{}
|
||||
b.Run("UpdateLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.IsEstablished()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SetEstablished", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.SetEstablished(i%2 == 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetLastSeen", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = conn.GetLastSeen()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
170
client/firewall/uspfilter/conntrack/icmp.go
Normal file
170
client/firewall/uspfilter/conntrack/icmp.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultICMPTimeout is the default timeout for ICMP connections
|
||||
DefaultICMPTimeout = 30 * time.Second
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
type ICMPConnKey struct {
|
||||
// Supports both IPv4 and IPv6
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
Sequence uint16 // ICMP sequence number
|
||||
ID uint16 // ICMP identifier
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
type ICMPConnTrack struct {
|
||||
BaseConnTrack
|
||||
Sequence uint16
|
||||
ID uint16
|
||||
}
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
},
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
t.connections[key] = conn
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
switch icmpType {
|
||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
||||
return true
|
||||
case uint8(layers.ICMPv4TypeEchoReply):
|
||||
// continue processing
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// makeICMPKey creates an ICMP connection key
|
||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||
return ICMPConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
}
|
||||
39
client/firewall/uspfilter/conntrack/icmp_test.go
Normal file
39
client/firewall/uspfilter/conntrack/icmp_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
352
client/firewall/uspfilter/conntrack/tcp.go
Normal file
352
client/firewall/uspfilter/conntrack/tcp.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package conntrack
|
||||
|
||||
// TODO: Send RST packets for invalid/timed-out connections
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// MSL (Maximum Segment Lifetime) is typically 2 minutes
|
||||
MSL = 2 * time.Minute
|
||||
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
|
||||
TimeWaitTimeout = 2 * MSL
|
||||
)
|
||||
|
||||
const (
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPFin uint8 = 0x01
|
||||
TCPRst uint8 = 0x04
|
||||
TCPPush uint8 = 0x08
|
||||
TCPUrg uint8 = 0x20
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTCPTimeout is the default timeout for established TCP connections
|
||||
DefaultTCPTimeout = 3 * time.Hour
|
||||
// TCPHandshakeTimeout is timeout for TCP handshake completion
|
||||
TCPHandshakeTimeout = 60 * time.Second
|
||||
// TCPCleanupInterval is how often we check for stale connections
|
||||
TCPCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int
|
||||
|
||||
const (
|
||||
TCPStateNew TCPState = iota
|
||||
TCPStateSynSent
|
||||
TCPStateSynReceived
|
||||
TCPStateEstablished
|
||||
TCPStateFinWait1
|
||||
TCPStateFinWait2
|
||||
TCPStateClosing
|
||||
TCPStateTimeWait
|
||||
TCPStateCloseWait
|
||||
TCPStateLastAck
|
||||
TCPStateClosed
|
||||
)
|
||||
|
||||
// TCPConnKey uniquely identifies a TCP connection
|
||||
type TCPConnKey struct {
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
State TCPState
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
done chan struct{}
|
||||
timeout time.Duration
|
||||
ipPool *PreallocatedIPs
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
tracker := &TCPTracker{
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
timeout: timeout,
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
// Use preallocated IPs
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
// Lock individual connection for state update
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.lastSeen.Store(now)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle RST packets
|
||||
if flags&TCPRst != 0 {
|
||||
conn.Lock()
|
||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
return true
|
||||
}
|
||||
conn.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
|
||||
return isEstablished || isValidState
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
return
|
||||
}
|
||||
|
||||
switch conn.State {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
conn.State = TCPStateSynSent
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateSynReceived
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
}
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateFinWait1
|
||||
} else {
|
||||
conn.State = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
// Simultaneous close - both sides sent FIN
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateLastAck
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
}
|
||||
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
}
|
||||
}
|
||||
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||
case TCPStateSynSent:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||
case TCPStateSynReceived:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPRst != 0 {
|
||||
return true
|
||||
}
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateFinWait1:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
case TCPStateFinWait2:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
case TCPStateClosing:
|
||||
// In CLOSING state, we should accept the final ACK
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateTimeWait:
|
||||
// In TIME_WAIT, we might see retransmissions
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateCloseWait:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TCPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
var timeout time.Duration
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
timeout = TimeWaitTimeout
|
||||
case conn.IsEstablished():
|
||||
timeout = t.timeout
|
||||
default:
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
lastSeen := conn.GetLastSeen()
|
||||
if time.Since(lastSeen) > timeout {
|
||||
// Return IPs to pool
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
// Clean up all remaining IPs
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
func isValidFlagCombination(flags uint8) bool {
|
||||
// Invalid: SYN+FIN
|
||||
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Invalid: RST with SYN or FIN
|
||||
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
308
client/firewall/uspfilter/conntrack/tcp_test.go
Normal file
308
client/firewall/uspfilter/conntrack/tcp_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Security Tests", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
flags uint8
|
||||
wantDrop bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "Block unsolicited SYN-ACK",
|
||||
flags: TCPSyn | TCPAck,
|
||||
wantDrop: true,
|
||||
desc: "Should block SYN-ACK without prior SYN",
|
||||
},
|
||||
{
|
||||
name: "Block invalid SYN-FIN",
|
||||
flags: TCPSyn | TCPFin,
|
||||
wantDrop: true,
|
||||
desc: "Should block invalid SYN-FIN combination",
|
||||
},
|
||||
{
|
||||
name: "Block unsolicited RST",
|
||||
flags: TCPRst,
|
||||
wantDrop: true,
|
||||
desc: "Should block RST without connection",
|
||||
},
|
||||
{
|
||||
name: "Block unsolicited ACK",
|
||||
flags: TCPAck,
|
||||
wantDrop: true,
|
||||
desc: "Should block ACK without connection",
|
||||
},
|
||||
{
|
||||
name: "Block data without connection",
|
||||
flags: TCPAck | TCPPush,
|
||||
wantDrop: true,
|
||||
desc: "Should block data without established connection",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Connection Flow Tests", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
test func(*testing.T)
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "Normal Handshake",
|
||||
test: func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Send initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
|
||||
// Receive SYN-ACK
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
// Send ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
|
||||
// Test data transfer
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
require.True(t, valid, "Data should be allowed after handshake")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Normal Close",
|
||||
test: func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// First establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
require.True(t, valid, "ACK for FIN should be allowed")
|
||||
|
||||
// Receive FIN from other side
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
require.True(t, valid, "FIN should be allowed")
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RST During Connection",
|
||||
test: func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// First establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Simultaneous Close",
|
||||
test: func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// First establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Both sides send FIN+ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||
|
||||
// Both sides send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
require.True(t, valid, "Final ACKs should be allowed")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupState func()
|
||||
sendRST func()
|
||||
wantValid bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "RST in established",
|
||||
setupState: func() {
|
||||
// Establish connection first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
},
|
||||
wantValid: true,
|
||||
desc: "Should accept RST for established connection",
|
||||
},
|
||||
{
|
||||
name: "RST without connection",
|
||||
setupState: func() {},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
},
|
||||
wantValid: false,
|
||||
desc: "Should reject RST without connection",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupState()
|
||||
tt.sendRST()
|
||||
|
||||
// Verify connection state is as expected
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateClosed, conn.State)
|
||||
require.False(t, conn.IsEstablished())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
158
client/firewall/uspfilter/conntrack/udp.go
Normal file
158
client/firewall/uspfilter/conntrack/udp.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultUDPTimeout is the default timeout for UDP connections
|
||||
DefaultUDPTimeout = 30 * time.Second
|
||||
// UDPCleanupInterval is how often we check for stale connections
|
||||
UDPCleanupInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
type UDPConnTrack struct {
|
||||
BaseConnTrack
|
||||
}
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
type UDPTracker struct {
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
tracker := &UDPTracker{
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
t.connections[key] = conn
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
conn.lastSeen.Store(now)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *UDPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *UDPTracker) Close() {
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return conn, true
|
||||
}
|
||||
|
||||
// Timeout returns the configured timeout duration for the tracker
|
||||
func (t *UDPTracker) Timeout() time.Duration {
|
||||
return t.timeout
|
||||
}
|
||||
243
client/firewall/uspfilter/conntrack/udp_test.go
Normal file
243
client/firewall/uspfilter/conntrack/udp_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewUDPTracker(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
wantTimeout time.Duration
|
||||
}{
|
||||
{
|
||||
name: "with custom timeout",
|
||||
timeout: 1 * time.Minute,
|
||||
wantTimeout: 1 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "with zero timeout uses default",
|
||||
timeout: 0,
|
||||
wantTimeout: DefaultUDPTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
assert.NotNil(t, tracker.cleanupTicker)
|
||||
assert.NotNil(t, tracker.done)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := tracker.connections[key]
|
||||
require.True(t, exists)
|
||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.True(t, conn.IsEstablished())
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1 * time.Second)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Track outbound connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
sleep time.Duration
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "valid inbound response",
|
||||
srcIP: dstIP, // Original destination is now source
|
||||
dstIP: srcIP, // Original source is now destination
|
||||
srcPort: dstPort, // Original destination port is now source
|
||||
dstPort: srcPort, // Original source port is now destination
|
||||
sleep: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid source IP",
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid destination IP",
|
||||
srcIP: dstIP,
|
||||
dstIP: net.ParseIP("192.168.1.4"),
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid source port",
|
||||
srcIP: dstIP,
|
||||
dstIP: srcIP,
|
||||
srcPort: 54321,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid destination port",
|
||||
srcIP: dstIP,
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: 54321,
|
||||
sleep: 0,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired connection",
|
||||
srcIP: dstIP,
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 2 * time.Second, // Longer than tracker timeout
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.sleep > 0 {
|
||||
time.Sleep(tt.sleep)
|
||||
}
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
// Use shorter intervals for testing
|
||||
timeout := 50 * time.Millisecond
|
||||
cleanupInterval := 25 * time.Millisecond
|
||||
|
||||
// Create tracker with custom cleanup interval
|
||||
tracker := &UDPTracker{
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go tracker.cleanupRoutine()
|
||||
|
||||
// Add some connections
|
||||
connections := []struct {
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.2"),
|
||||
dstIP: net.ParseIP("192.168.1.3"),
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
},
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.5"),
|
||||
srcPort: 12346,
|
||||
dstPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
assert.Len(t, tracker.connections, 2)
|
||||
|
||||
// Wait for connection timeout and cleanup interval
|
||||
time.Sleep(timeout + 2*cleanupInterval)
|
||||
|
||||
tracker.mutex.RLock()
|
||||
connCount := len(tracker.connections)
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
// Verify connections were cleaned up
|
||||
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
|
||||
|
||||
// Properly close the tracker
|
||||
tracker.Close()
|
||||
}
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
@@ -15,7 +13,6 @@ type Rule struct {
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction firewall.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
drop bool
|
||||
|
||||
@@ -3,6 +3,9 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -11,18 +14,23 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
@@ -31,7 +39,9 @@ type RuleSet map[string]Rule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
@@ -39,6 +49,11 @@ type Manager struct {
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
stateful bool
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -70,6 +85,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
@@ -87,6 +104,16 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
stateful: !disableConntrack,
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@@ -95,6 +122,10 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
if m.nativeFirewall == nil {
|
||||
return false
|
||||
@@ -103,33 +134,32 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
|
||||
// RemoveRoutingRules removes a routing firewall rule
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
direction firewall.RuleDirection,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
@@ -137,7 +167,6 @@ func (m *Manager) AddFiltering(
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == firewall.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
@@ -173,23 +202,30 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -198,38 +234,37 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if r.direction == firewall.RuleDirectionIN {
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||
return m.processOutgoingHooks(packetData)
|
||||
}
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules, true)
|
||||
return m.dropFilter(packetData, m.incomingRules)
|
||||
}
|
||||
|
||||
// dropFilter implements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@@ -237,61 +272,215 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
log.Tracef("not enough levels in network packet")
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Always process UDP hooks
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
// Track UDP state only if enabled
|
||||
if m.stateful {
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
// Track other protocols only if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return d.ip4.SrcIP, d.ip4.DstIP
|
||||
case layers.LayerTypeIPv6:
|
||||
return d.ip6.SrcIP, d.ip6.DstIP
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
flags,
|
||||
)
|
||||
}
|
||||
|
||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
var flags uint8
|
||||
if tcp.SYN {
|
||||
flags |= conntrack.TCPSyn
|
||||
}
|
||||
if tcp.ACK {
|
||||
flags |= conntrack.TCPAck
|
||||
}
|
||||
if tcp.FIN {
|
||||
flags |= conntrack.TCPFin
|
||||
}
|
||||
if tcp.RST {
|
||||
flags |= conntrack.TCPRst
|
||||
}
|
||||
if tcp.PSH {
|
||||
flags |= conntrack.TCPPush
|
||||
}
|
||||
if tcp.URG {
|
||||
flags |= conntrack.TCPUrg
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
m.udpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||
m.icmpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
// TODO: Disable router if --disable-server-router is set
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
return true
|
||||
}
|
||||
|
||||
ipLayer := d.decoded[0]
|
||||
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
|
||||
return false
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
|
||||
var ip net.IP
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip4.SrcIP
|
||||
} else {
|
||||
ip = d.ip4.DstIP
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip6.SrcIP
|
||||
} else {
|
||||
ip = d.ip6.DstIP
|
||||
}
|
||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||
if ok {
|
||||
return filter
|
||||
// Check connection state only if enabled
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||
if ok {
|
||||
return filter
|
||||
|
||||
return m.applyRules(srcIP, packetData, rules, d)
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
return false
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||
if ok {
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
log.Tracef("not enough levels in network packet")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return m.tcpTracker.IsValidInbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
getTCPFlags(&d.tcp),
|
||||
)
|
||||
|
||||
case layers.LayerTypeUDP:
|
||||
return m.udpTracker.IsValidInbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
)
|
||||
|
||||
case layers.LayerTypeICMPv4:
|
||||
return m.icmpTracker.IsValidInbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
d.icmp4.TypeCode.Type(),
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
// Default policy: DROP ALL
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -337,7 +526,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
return rule.drop, true
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop, true
|
||||
}
|
||||
@@ -362,7 +550,6 @@ func (m *Manager) AddUDPPacketHook(
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: dPort,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||
udpHook: hook,
|
||||
}
|
||||
@@ -373,7 +560,6 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
@@ -392,19 +578,22 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeleteRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
rule := r
|
||||
return m.DeleteRule(&rule)
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
998
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
998
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
@@ -0,0 +1,998 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range
|
||||
func generateRandomIPs(n int) []net.IP {
|
||||
ips := make([]net.IP, n)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for i := 0; i < n; {
|
||||
ip := make(net.IP, 4)
|
||||
ip[0] = 100
|
||||
ip[1] = byte(64 + rand.Intn(63)) // 64-126
|
||||
ip[2] = byte(rand.Intn(256))
|
||||
ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255
|
||||
|
||||
key := ip.String()
|
||||
if !seen[key] {
|
||||
ips[i] = ip
|
||||
seen[key] = true
|
||||
i++
|
||||
}
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: protocol,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4))
|
||||
transportLayer = udp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||
require.NoError(b, err)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// BenchmarkCoreFiltering focuses on the essential performance comparisons between
|
||||
// stateful and stateless filtering approaches
|
||||
func BenchmarkCoreFiltering(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
stateful bool
|
||||
setupFunc func(*Manager)
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "stateless_single_allow_all",
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.ActionAccept, "", "allow all")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
},
|
||||
{
|
||||
name: "stateful_no_rules",
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// No explicit rules - rely purely on connection tracking
|
||||
},
|
||||
desc: "Pure connection tracking without any rules",
|
||||
},
|
||||
{
|
||||
name: "stateless_explicit_return",
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{1024 + i}},
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
desc: "Explicit rules matching return traffic patterns without state",
|
||||
},
|
||||
{
|
||||
name: "stateful_with_established",
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.ActionDrop, "", "default drop")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
},
|
||||
}
|
||||
|
||||
// Test both TCP and UDP
|
||||
protocols := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
}{
|
||||
{"TCP", layers.IPProtocolTCP},
|
||||
{"UDP", layers.IPProtocolUDP},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
for _, proto := range protocols {
|
||||
b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) {
|
||||
// Configure stateful/stateless mode
|
||||
if !sc.stateful {
|
||||
require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1"))
|
||||
} else {
|
||||
require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m"))
|
||||
}
|
||||
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Apply scenario-specific setup
|
||||
sc.setupFunc(manager)
|
||||
|
||||
// Generate test packets
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
srcPort := uint16(1024 + b.N%60000)
|
||||
dstPort := uint16(80)
|
||||
|
||||
outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto)
|
||||
inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto)
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStateScaling measures how performance scales with connection table size
|
||||
func BenchmarkStateScaling(b *testing.B) {
|
||||
connCounts := []int{100, 1000, 10000, 100000}
|
||||
|
||||
for _, count := range connCounts {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Pre-populate connection table
|
||||
srcIPs := generateRandomIPs(count)
|
||||
dstIPs := generateRandomIPs(count)
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP)
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.processOutgoingHooks(testOut)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, manager.incomingRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEstablishmentOverhead measures the overhead of connection establishment
|
||||
func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
established bool
|
||||
}{
|
||||
{"established", true},
|
||||
{"new", false},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic
|
||||
func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
state string // "new", "established", "post_handshake" (TCP only)
|
||||
setupFunc func(*Manager)
|
||||
genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "allow_non_wg_tcp_new",
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
|
||||
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
},
|
||||
desc: "Allow non-WG: TCP new connection",
|
||||
},
|
||||
{
|
||||
name: "allow_non_wg_tcp_established",
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
// Generate packets with ACK flag for established connection
|
||||
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
|
||||
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
|
||||
},
|
||||
desc: "Allow non-WG: TCP established connection",
|
||||
},
|
||||
{
|
||||
name: "allow_non_wg_udp_new",
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||
},
|
||||
desc: "Allow non-WG: UDP new connection",
|
||||
},
|
||||
{
|
||||
name: "allow_non_wg_udp_established",
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||
},
|
||||
desc: "Allow non-WG: UDP established connection",
|
||||
},
|
||||
{
|
||||
name: "stateful_tcp_new",
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
|
||||
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
},
|
||||
desc: "Stateful: TCP new connection",
|
||||
},
|
||||
{
|
||||
name: "stateful_tcp_established",
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
// Generate established TCP packets (ACK flag)
|
||||
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
|
||||
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
|
||||
},
|
||||
desc: "Stateful: TCP established connection",
|
||||
},
|
||||
{
|
||||
name: "stateful_tcp_post_handshake",
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "post_handshake",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
// Generate packets with PSH+ACK flags for data transfer
|
||||
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
},
|
||||
desc: "Stateful: TCP post-handshake data transfer",
|
||||
},
|
||||
{
|
||||
name: "stateful_udp_new",
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||
},
|
||||
desc: "Stateful: UDP new connection",
|
||||
},
|
||||
{
|
||||
name: "stateful_udp_established",
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
|
||||
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
|
||||
},
|
||||
desc: "Stateful: UDP established connection",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
// Setup scenario
|
||||
sc.setupFunc(manager)
|
||||
|
||||
// Use IPs outside WG range for routed network simulation
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
outbound, inbound := sc.genPackets(srcIP, dstIP)
|
||||
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.processOutgoingHooks(outbound)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, manager.incomingRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var scenarios = []struct {
|
||||
name string
|
||||
stateful bool // Whether conntrack is enabled
|
||||
rules bool // Whether to add return traffic rules
|
||||
routed bool // Whether to test routed network traffic
|
||||
connCount int // Number of concurrent connections
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
name: "stateless_with_rules_100conns",
|
||||
stateful: false,
|
||||
rules: true,
|
||||
routed: false,
|
||||
connCount: 100,
|
||||
desc: "Pure stateless with return traffic rules, 100 conns",
|
||||
},
|
||||
{
|
||||
name: "stateless_with_rules_1000conns",
|
||||
stateful: false,
|
||||
rules: true,
|
||||
routed: false,
|
||||
connCount: 1000,
|
||||
desc: "Pure stateless with return traffic rules, 1000 conns",
|
||||
},
|
||||
{
|
||||
name: "stateful_no_rules_100conns",
|
||||
stateful: true,
|
||||
rules: false,
|
||||
routed: false,
|
||||
connCount: 100,
|
||||
desc: "Pure stateful tracking without rules, 100 conns",
|
||||
},
|
||||
{
|
||||
name: "stateful_no_rules_1000conns",
|
||||
stateful: true,
|
||||
rules: false,
|
||||
routed: false,
|
||||
connCount: 1000,
|
||||
desc: "Pure stateful tracking without rules, 1000 conns",
|
||||
},
|
||||
{
|
||||
name: "stateful_with_rules_100conns",
|
||||
stateful: true,
|
||||
rules: true,
|
||||
routed: false,
|
||||
connCount: 100,
|
||||
desc: "Combined stateful + rules (current implementation), 100 conns",
|
||||
},
|
||||
{
|
||||
name: "stateful_with_rules_1000conns",
|
||||
stateful: true,
|
||||
rules: true,
|
||||
routed: false,
|
||||
connCount: 1000,
|
||||
desc: "Combined stateful + rules (current implementation), 1000 conns",
|
||||
},
|
||||
{
|
||||
name: "routed_network_100conns",
|
||||
stateful: true,
|
||||
rules: false,
|
||||
routed: true,
|
||||
connCount: 100,
|
||||
desc: "Routed network traffic (non-WG), 100 conns",
|
||||
},
|
||||
{
|
||||
name: "routed_network_1000conns",
|
||||
stateful: true,
|
||||
rules: false,
|
||||
routed: true,
|
||||
connCount: 1000,
|
||||
desc: "Routed network traffic (non-WG), 1000 conns",
|
||||
},
|
||||
}
|
||||
|
||||
// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns
|
||||
func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
// Configure stateful/stateless mode
|
||||
if !sc.stateful {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
} else {
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
}
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Generate IPs for connections
|
||||
srcIPs := make([]net.IP, sc.connCount)
|
||||
dstIPs := make([]net.IP, sc.connCount)
|
||||
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
if sc.routed {
|
||||
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||
} else {
|
||||
srcIPs[i] = generateRandomIPs(1)[0]
|
||||
dstIPs[i] = generateRandomIPs(1)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Create established connections
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
inPackets := make([][]byte, sc.connCount)
|
||||
outPackets := make([][]byte, sc.connCount)
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
// Server -> Client (inbound)
|
||||
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
// Client -> Server (outbound)
|
||||
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
connIdx := i % sc.connCount
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkShortLivedConnections tests performance with many short-lived connections
|
||||
func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
// Configure stateful/stateless mode
|
||||
if !sc.stateful {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
} else {
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
}
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Generate IPs for connections
|
||||
srcIPs := make([]net.IP, sc.connCount)
|
||||
dstIPs := make([]net.IP, sc.connCount)
|
||||
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
if sc.routed {
|
||||
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||
} else {
|
||||
srcIPs[i] = generateRandomIPs(1)[0]
|
||||
dstIPs[i] = generateRandomIPs(1)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Create packet patterns for a complete HTTP-like short connection:
|
||||
// 1. Initial handshake (SYN, SYN-ACK, ACK)
|
||||
// 2. HTTP Request (PSH+ACK from client)
|
||||
// 3. HTTP Response (PSH+ACK from server)
|
||||
// 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK)
|
||||
type connPackets struct {
|
||||
syn []byte
|
||||
synAck []byte
|
||||
ack []byte
|
||||
request []byte
|
||||
response []byte
|
||||
finClient []byte
|
||||
ackServer []byte
|
||||
finServer []byte
|
||||
ackClient []byte
|
||||
}
|
||||
|
||||
// Generate all possible connection patterns
|
||||
patterns := make([]connPackets, sc.connCount)
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
patterns[i] = connPackets{
|
||||
// Handshake
|
||||
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
|
||||
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
|
||||
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||
|
||||
// Data transfer
|
||||
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||
|
||||
// Connection teardown
|
||||
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPAck)),
|
||||
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Each iteration creates a new short-lived connection
|
||||
connIdx := i % sc.connCount
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel
|
||||
func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
// Configure stateful/stateless mode
|
||||
if !sc.stateful {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
} else {
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
}
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Generate IPs for connections
|
||||
srcIPs := make([]net.IP, sc.connCount)
|
||||
dstIPs := make([]net.IP, sc.connCount)
|
||||
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
if sc.routed {
|
||||
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||
} else {
|
||||
srcIPs[i] = generateRandomIPs(1)[0]
|
||||
dstIPs[i] = generateRandomIPs(1)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Create established connections
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, manager.incomingRules)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
inPackets := make([][]byte, sc.connCount)
|
||||
outPackets := make([][]byte, sc.connCount)
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
// Each goroutine gets its own counter to distribute load
|
||||
counter := 0
|
||||
for pb.Next() {
|
||||
connIdx := counter % sc.connCount
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel
|
||||
func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
// Configure stateful/stateless mode
|
||||
if !sc.stateful {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
} else {
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
}
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Generate IPs and pre-generate all packet patterns
|
||||
srcIPs := make([]net.IP, sc.connCount)
|
||||
dstIPs := make([]net.IP, sc.connCount)
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
if sc.routed {
|
||||
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
|
||||
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
|
||||
} else {
|
||||
srcIPs[i] = generateRandomIPs(1)[0]
|
||||
dstIPs[i] = generateRandomIPs(1)[0]
|
||||
}
|
||||
}
|
||||
|
||||
type connPackets struct {
|
||||
syn []byte
|
||||
synAck []byte
|
||||
ack []byte
|
||||
request []byte
|
||||
response []byte
|
||||
finClient []byte
|
||||
ackServer []byte
|
||||
finServer []byte
|
||||
ackClient []byte
|
||||
}
|
||||
|
||||
patterns := make([]connPackets, sc.connCount)
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
patterns[i] = connPackets{
|
||||
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
|
||||
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
|
||||
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
|
||||
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPAck)),
|
||||
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
|
||||
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
counter := 0
|
||||
for pb.Next() {
|
||||
connIdx := counter % sc.connCount
|
||||
counter++
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck, manager.incomingRules)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response, manager.incomingRules)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer, manager.incomingRules)
|
||||
manager.dropFilter(p.finServer, manager.incomingRules)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
// Set TCP flags
|
||||
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
||||
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
||||
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
||||
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
||||
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,15 +12,17 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(iface.PacketFilter) error
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
if i.SetFilterFunc == nil {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -35,7 +38,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -49,10 +52,10 @@ func TestManagerCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddFiltering(t *testing.T) {
|
||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
isSetFilterCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error {
|
||||
SetFilterFunc: func(device.PacketFilter) error {
|
||||
isSetFilterCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -67,11 +70,10 @@ func TestManagerAddFiltering(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -90,7 +92,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
||||
|
||||
func TestManagerDeleteRule(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -102,37 +104,15 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP("192.168.1.1")
|
||||
proto = fw.ProtocolTCP
|
||||
port = &fw.Port{Values: []int{80}}
|
||||
direction = fw.RuleDirectionIN
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule {
|
||||
err = m.DeleteRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
@@ -140,7 +120,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
err = m.DeleteRule(r)
|
||||
err = m.DeletePeerRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
@@ -184,10 +164,10 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &Manager{
|
||||
incomingRules: map[string]RuleSet{},
|
||||
outgoingRules: map[string]RuleSet{},
|
||||
}
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
@@ -222,10 +202,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if tt.expDir != addedRule.direction {
|
||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
@@ -236,7 +212,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -248,17 +224,16 @@ func TestManagerReset(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Reset()
|
||||
err = m.Reset(nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
@@ -271,7 +246,7 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
@@ -286,11 +261,10 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -312,7 +286,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
||||
return
|
||||
}
|
||||
payload := gopacket.Payload([]byte("test"))
|
||||
payload := gopacket.Payload("test")
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
@@ -324,12 +298,12 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
||||
if m.dropFilter(buf.Bytes(), m.incomingRules) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
|
||||
if err = m.Reset(); err != nil {
|
||||
if err = m.Reset(nil); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -339,7 +313,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
func TestRemovePacketHook(t *testing.T) {
|
||||
// creating mock iface
|
||||
iface := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
@@ -347,6 +321,9 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
@@ -383,19 +360,101 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
net.ParseIP("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
hookCalled = true
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: net.ParseIP("100.10.0.1"),
|
||||
DstIP: net.ParseIP("100.10.0.100"),
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udp := &layers.UDP{
|
||||
SrcPort: 51334,
|
||||
DstPort: 53,
|
||||
}
|
||||
|
||||
err = udp.SetNetworkLayerForChecksum(ipv4)
|
||||
require.NoError(t, err)
|
||||
payload := gopacket.Payload("test")
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.processOutgoingHooks(buf.Bytes())
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
// Test non-UDP packet is ignored
|
||||
ipv4.Protocol = layers.IPProtocolTCP
|
||||
buf = gopacket.NewSerializeBuffer()
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.processOutgoingHooks(buf.Bytes())
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -405,11 +464,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -417,3 +472,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
|
||||
// Set up packet parameters
|
||||
srcIP := net.ParseIP("100.10.0.1")
|
||||
dstIP := net.ParseIP("100.10.0.100")
|
||||
srcPort := uint16(51334)
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Create outbound packet
|
||||
outboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
outboundUDP := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
|
||||
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
outboundBuf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
err = gopacket.SerializeLayers(outboundBuf, opts,
|
||||
outboundIPv4,
|
||||
outboundUDP,
|
||||
gopacket.Payload("test"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||
|
||||
// Create valid inbound response packet
|
||||
inboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: dstIP, // Original destination is now source
|
||||
DstIP: srcIP, // Original source is now destination
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
inboundUDP := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
|
||||
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
|
||||
}
|
||||
|
||||
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
inboundBuf := gopacket.NewSerializeBuffer()
|
||||
err = gopacket.SerializeLayers(inboundBuf, opts,
|
||||
inboundIPv4,
|
||||
inboundUDP,
|
||||
gopacket.Payload("response"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
// Test roundtrip response handling over time
|
||||
checkPoints := []struct {
|
||||
sleep time.Duration
|
||||
shouldAllow bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
sleep: 0,
|
||||
shouldAllow: true,
|
||||
description: "Immediate response should be allowed",
|
||||
},
|
||||
{
|
||||
sleep: 50 * time.Millisecond,
|
||||
shouldAllow: true,
|
||||
description: "Response within timeout should be allowed",
|
||||
},
|
||||
{
|
||||
sleep: 100 * time.Millisecond,
|
||||
shouldAllow: true,
|
||||
description: "Response at half timeout should be allowed",
|
||||
},
|
||||
{
|
||||
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
|
||||
sleep: 250 * time.Millisecond,
|
||||
shouldAllow: false,
|
||||
description: "Response after timeout should be dropped",
|
||||
},
|
||||
}
|
||||
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
if cp.shouldAllow {
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
require.True(t, exists, "Connection should still exist during valid window")
|
||||
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||
"LastSeen should be updated for valid responses")
|
||||
}
|
||||
}
|
||||
|
||||
// Test invalid response packets (while connection is expired)
|
||||
invalidCases := []struct {
|
||||
name string
|
||||
modifyFunc func(*layers.IPv4, *layers.UDP)
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "wrong source IP",
|
||||
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||
ip.SrcIP = net.ParseIP("100.10.0.101")
|
||||
},
|
||||
description: "Response from wrong IP should be dropped",
|
||||
},
|
||||
{
|
||||
name: "wrong destination IP",
|
||||
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||
ip.DstIP = net.ParseIP("100.10.0.2")
|
||||
},
|
||||
description: "Response to wrong IP should be dropped",
|
||||
},
|
||||
{
|
||||
name: "wrong source port",
|
||||
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||
udp.SrcPort = 54
|
||||
},
|
||||
description: "Response from wrong port should be dropped",
|
||||
},
|
||||
{
|
||||
name: "wrong destination port",
|
||||
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||
udp.DstPort = 51335
|
||||
},
|
||||
description: "Response to wrong port should be dropped",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testIPv4 := *inboundIPv4
|
||||
testUDP := *inboundUDP
|
||||
|
||||
tc.modifyFunc(&testIPv4, &testUDP)
|
||||
|
||||
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
testBuf := gopacket.NewSerializeBuffer()
|
||||
err = gopacket.SerializeLayers(testBuf, opts,
|
||||
&testIPv4,
|
||||
&testUDP,
|
||||
gopacket.Payload("response"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
||||
5
client/iface/bind/endpoint.go
Normal file
5
client/iface/bind/endpoint.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package bind
|
||||
|
||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
type Endpoint = wgConn.StdNetEndpoint
|
||||
303
client/iface/bind/ice_bind.go
Normal file
303
client/iface/bind/ice_bind.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
// 1. filter out STUN messages and handle them
|
||||
// 2. forward the received packets to the WireGuard interface from the relayed connection
|
||||
//
|
||||
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
|
||||
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
|
||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
RecvChan chan RecvMessage
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn FilterFn
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||
closedChan chan struct{}
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
ib,
|
||||
}
|
||||
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
||||
return ib
|
||||
}
|
||||
|
||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
s.closed = false
|
||||
s.closedChanMu.Lock()
|
||||
s.closedChan = make(chan struct{})
|
||||
s.closedChanMu.Unlock()
|
||||
fns, port, err := s.StdNetBind.Open(uport)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
fns = append(fns, s.receiveRelayed)
|
||||
return fns, port, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) Close() error {
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
|
||||
close(s.closedChan)
|
||||
|
||||
return s.StdNetBind.Close()
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
if s.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// force IPv4
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeAddr] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
|
||||
return fakeUDPAddr, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
log.Warnf("failed to convert IP to netip.Addr")
|
||||
return
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
defer b.endpointsMu.Unlock()
|
||||
delete(b.endpoints, fakeAddr)
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
b.endpointsMu.Lock()
|
||||
conn, ok := b.endpoints[ep.DstIP()]
|
||||
b.endpointsMu.Unlock()
|
||||
if !ok {
|
||||
return b.StdNetBind.Send(bufs, ep)
|
||||
}
|
||||
|
||||
for _, buf := range bufs {
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := getMessages(msgsPool)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
for i := range buffers {
|
||||
if !stun.IsMessage(buffers[i]) {
|
||||
continue
|
||||
}
|
||||
|
||||
msg, err := s.parseSTUNMessage(buffers[i][:n])
|
||||
if err != nil {
|
||||
buffers[i] = []byte{}
|
||||
return true, err
|
||||
}
|
||||
|
||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
||||
if muxErr != nil {
|
||||
log.Warnf("failed to handle STUN packet")
|
||||
}
|
||||
|
||||
buffers[i] = []byte{}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
||||
msg := &stun.Message{
|
||||
Raw: raw,
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
|
||||
// WireGuard. Critical part is do not block if the Closed() has been called.
|
||||
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
c.closedChanMu.RLock()
|
||||
defer c.closedChanMu.RUnlock()
|
||||
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-c.RecvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
copy(buffs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
newAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||
Port: peerAddress.Port,
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
@@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
5
client/iface/configurer/err.go
Normal file
5
client/iface/configurer/err.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package configurer
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrPeerNotFound = errors.New("peer not found")
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -12,18 +12,17 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type wgKernelConfigurer struct {
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
}
|
||||
|
||||
func newWGConfigurer(deviceName string) wgConfigurer {
|
||||
wgc := &wgKernelConfigurer{
|
||||
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
|
||||
return &KernelConfigurer{
|
||||
deviceName: deviceName,
|
||||
}
|
||||
return wgc
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error {
|
||||
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
@@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
// parse allowed ips
|
||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
||||
if err != nil {
|
||||
@@ -56,8 +55,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
|
||||
return err
|
||||
}
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
Endpoint: endpoint,
|
||||
@@ -74,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) removePeer(peerKey string) error {
|
||||
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -95,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
|
||||
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -122,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
|
||||
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse allowed IP: %w", err)
|
||||
@@ -164,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
|
||||
@@ -188,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer
|
||||
return wgtypes.Peer{}, ErrPeerNotFound
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
|
||||
func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -204,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
|
||||
return wg.ConfigureDevice(c.deviceName, config)
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) close() {
|
||||
func (c *KernelConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) {
|
||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build linux || windows || freebsd
|
||||
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
const WgInterfaceDefault = "wt0"
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build darwin
|
||||
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
||||
const WgInterfaceDefault = "utun100"
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -1,4 +1,4 @@
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -1,4 +1,4 @@
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
@@ -19,15 +19,15 @@ import (
|
||||
|
||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||
|
||||
type wgUSPConfigurer struct {
|
||||
type WGUSPConfigurer struct {
|
||||
device *device.Device
|
||||
deviceName string
|
||||
|
||||
uapiListener net.Listener
|
||||
}
|
||||
|
||||
func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
|
||||
wgCfg := &wgUSPConfigurer{
|
||||
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
||||
wgCfg := &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error {
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
if err != nil {
|
||||
@@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
// parse allowed ips
|
||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
||||
if err != nil {
|
||||
@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
|
||||
return err
|
||||
}
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: false,
|
||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||
AllowedIPs: []net.IPNet{*ipNet},
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
PresharedKey: preSharedKey,
|
||||
@@ -79,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *wgUSPConfigurer) removePeer(peerKey string) error {
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -96,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error {
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
|
||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -120,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
|
||||
return c.device.IpcSet(toWgUserspaceString(config))
|
||||
}
|
||||
|
||||
func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
|
||||
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
||||
ipc, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -184,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
|
||||
}
|
||||
|
||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||
func (t *wgUSPConfigurer) startUAPI() {
|
||||
func (t *WGUSPConfigurer) startUAPI() {
|
||||
var err error
|
||||
t.uapiListener, err = openUAPI(t.deviceName)
|
||||
if err != nil {
|
||||
@@ -206,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() {
|
||||
}(t.uapiListener)
|
||||
}
|
||||
|
||||
func (t *wgUSPConfigurer) close() {
|
||||
func (t *WGUSPConfigurer) Close() {
|
||||
if t.uapiListener != nil {
|
||||
err := t.uapiListener.Close()
|
||||
if err != nil {
|
||||
@@ -222,7 +223,7 @@ func (t *wgUSPConfigurer) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) {
|
||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
||||
ipc, err := t.device.IpcGet()
|
||||
if err != nil {
|
||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
||||
@@ -1,4 +1,4 @@
|
||||
package iface
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
9
client/iface/configurer/wgstats.go
Normal file
9
client/iface/configurer/wgstats.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package configurer
|
||||
|
||||
import "time"
|
||||
|
||||
type WGStats struct {
|
||||
LastHandshake time.Time
|
||||
TxBytes int64
|
||||
RxBytes int64
|
||||
}
|
||||
18
client/iface/device.go
Normal file
18
client/iface/device.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !android
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address WGAddress) error
|
||||
WgAddress() WGAddress
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package iface
|
||||
package device
|
||||
|
||||
// TunAdapter is an interface for create tun device from external service
|
||||
type TunAdapter interface {
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user