Compare commits

...

197 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
Pascal Fischer
ab2e3fec72 expose resource type consts 2025-03-12 13:49:24 +01:00
Hakan Sariman
47f88f7057 Refactor routeIDLookup methods to use Addr() for resolved IP operations 2025-03-11 19:43:58 +08:00
Hakan Sariman
ee33a6ed7c Refactor RemoveLocalPeerStateRoute to eliminate resourceId parameter 2025-03-11 13:19:30 +08:00
Hakan Sariman
da662cfd08 Add source and destination resource IDs to FlowFields 2025-03-11 13:12:54 +08:00
Hakan Sariman
ed2ee1ee9d Merge branch 'feature/flow' into feat/flow-resid 2025-03-11 13:08:11 +08:00
Viktor Liu
76d73548d6 Fix more conflicts 2025-03-10 18:46:01 +01:00
Viktor Liu
11828a064a Fix conflict 2025-03-10 18:35:32 +01:00
Viktor Liu
0c2a3dd937 Merge branch 'main' into feature/flow 2025-03-10 18:30:45 +01:00
Zoltan Papp
cd9eff5331 Increase the timeout to 50 sec (#3481) 2025-03-10 18:23:47 +01:00
Viktor Liu
47dcf8d68c Fix forwarder IP source/destination (#3463) 2025-03-10 14:55:07 +01:00
Viktor Liu
80ceb80197 [client] Ignore candidates that are part of the the wireguard subnet (#3472) 2025-03-10 13:59:21 +01:00
Bethuel Mmbaga
cc8f6bcaf3 [management] Fix tests circular dependency (#3460)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-10 15:54:36 +03:00
Zoltan Papp
636a0e2475 [client] Fix engine restart (#3435)
- Refactor the network monitoring to handle one event and it after return
- In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer
- Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events.
2025-03-10 13:32:12 +01:00
Viktor Liu
e66e329bf6 [client] Add option to autostart netbird ui in the Windows installer (#3469) 2025-03-10 13:19:17 +01:00
Zoltan Papp
aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00
Zoltan Papp
6bef474e9e [client] Prevent panic in case of double close call (#3475)
Prevent panic in case of double close call
2025-03-10 13:16:28 +01:00
Maycon Santos
81040ff80a [docs] Update typo (#3477) 2025-03-10 11:52:36 +01:00
Viktor Liu
c73481aee4 [client] Enable windows stderr logs by default (#3476) 2025-03-10 11:30:49 +01:00
Hakan Sariman
92286b2541 Implement routeIDLookup for managing local and remote route IDs 2025-03-10 15:58:45 +08:00
Maycon Santos
d8bcf745b0 update integrations 2025-03-09 19:32:38 +01:00
Maycon Santos
8430139d80 fix missing method 2025-03-09 19:03:57 +01:00
Maycon Santos
a2962b4ce0 sync go.sum 2025-03-09 18:50:20 +01:00
Maycon Santos
16fffdb75b sync changes from #3426 2025-03-09 18:48:48 +01:00
Maycon Santos
036cecbf46 update integrations and go mod 2025-03-09 18:47:05 +01:00
Maycon Santos
3482852bb6 sync proto and sum 2025-03-09 18:02:33 +01:00
Maycon Santos
fd62665b1f Merge branch 'main' into feature/flow
# Conflicts:
#	client/cmd/testutil_test.go
#	client/firewall/iptables/router_linux.go
#	client/firewall/nftables/router_linux.go
#	client/firewall/uspfilter/allow_netbird.go
#	client/firewall/uspfilter/allow_netbird_windows.go
#	client/firewall/uspfilter/uspfilter_test.go
#	client/internal/engine.go
#	client/internal/engine_test.go
#	client/server/server_test.go
#	go.mod
#	go.sum
#	management/client/client_test.go
#	management/cmd/management.go
#	management/proto/management.pb.go
#	management/proto/management.proto
#	management/server/account.go
#	management/server/account_test.go
#	management/server/dns_test.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/tools.go
#	management/server/integrations/port_forwarding/controller.go
#	management/server/management_proto_test.go
#	management/server/management_test.go
#	management/server/nameserver_test.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/route_test.go
2025-03-09 17:42:16 +01:00
Viktor Liu
fc1da94520 [client, management] Add port forwarding (#3275)
Add initial support to ingress ports on the client code.

- new types where added
- new protocol messages and controller
2025-03-09 16:06:43 +01:00
Hakan Sariman
1ffe48f0d4 Add nil check in CheckRoutes to prevent potential panic 2025-03-08 12:54:33 +03:00
Hakan Sariman
a3b8a21385 Refactor CheckRoutes to return resource IDs for matching source and destination addresses 2025-03-08 12:26:53 +03:00
Hakan Sariman
86492b88c4 Refactor route handling to simplify route information and improve state management 2025-03-08 12:25:35 +03:00
Hakan Sariman
d08a629f9e Merge branch 'feature/flow' into feat/flow-resid 2025-03-08 12:18:02 +03:00
Viktor Liu
36da464413 Fix tracer test 2025-03-07 17:19:10 +01:00
Hakan Sariman
268e3404d3 Merge branch 'feature/flow' into feat/flow-resid 2025-03-07 18:52:11 +03:00
Hakan Sariman
54d0591833 Refactor route handling to use RouteWithResourceId for improved state management 2025-03-07 18:43:49 +03:00
Muzammil
ae6b61301c Muz/netbird dashboards (#3458)
* added all 3 dashboards

* update readme
2025-03-07 16:13:11 +01:00
Viktor Liu
86370a0e7b Use bytes for flows event id (#3439) 2025-03-07 16:12:47 +01:00
Philippe Vaucher
a444e551b3 [misc] Traefik config improvements (#3346)
* Remove deprecated docker-compose version

* Prettify docker-compose files

* Backports missing logging entries

* Fix signal port

* Add missing relay configuration

* Serve management over 33073 to avoid confusion
2025-03-07 16:10:11 +01:00
Zoltan Papp
53b9a2002f Print out the goroutine id (#3433)
The TXT logger prints out the actual go routine ID

This feature depends on 'loggoroutine' build tag

```go build -tags loggoroutine```
2025-03-07 14:06:47 +01:00
Viktor Liu
cb16d0f45f Align packet tracer behavior with actual code paths (#3424) 2025-03-07 14:03:45 +01:00
Viktor Liu
e8d8bd8f18 Add peer traffic rule IDs to allowed connections in flows (#3442) 2025-03-07 13:56:26 +01:00
Viktor Liu
8b07f21c28 Don't track intercepted packets (#3448) 2025-03-07 13:56:16 +01:00
Viktor Liu
54be772ffd Handle flow updates (#3455) 2025-03-07 13:56:00 +01:00
Zoltan Papp
4b76d93cec [client] Fix TURN-Relay switch (#3456)
- When a peer is connected with TURN and a Relay connection is established, do not force switching to Relay. Keep using TURN until disconnection.

-In the proxy preparation phase, the Bind Proxy does not set the remote conn as a fake address for Bind. When running the Work() function, the proper proxy instance updates the conn inside the Bind.
2025-03-07 12:00:25 +01:00
Viktor Liu
3c3a454e61 Fix merge regression 2025-03-06 16:54:15 +01:00
Viktor Liu
5ff77b3595 Add flow userspace counters (#3438) 2025-03-06 16:52:56 +01:00
Viktor Liu
b180edbe5c Track icmp with id only (#3447) 2025-03-06 14:51:23 +01:00
Hakan Sariman
de3b5c78d7 Fix nil pointer dereference in CheckRoutes method 2025-03-06 14:10:31 +03:00
Hakan Sariman
0b42f40cf6 Refactor route management to include resource IDs in state handling 2025-03-06 13:51:46 +03:00
Viktor Liu
062d1ec76f [misc] Update bug-issue-report.md template (#3449) 2025-03-06 01:10:37 +01:00
Viktor Liu
0a042ac36d Fix merge conflict 2025-03-05 19:11:20 +01:00
Viktor Liu
c111675dd8 [client] Handle large DNS packets in dns route resolution (#3441) 2025-03-05 18:57:17 +01:00
Hakan Sariman
e7f921d787 [client] add resource id fields to netflow events 2025-03-05 20:35:52 +03:00
Viktor Liu
e9f11fb11b Replace net.IP with netip.Addr (#3425) 2025-03-05 18:28:05 +01:00
hakansa
419ed275fa Handle TCP RST flag to transition connection state to closed (#3432) 2025-03-05 18:25:42 +01:00
Viktor Liu
2d4fcaf186 Fix proto numbering (#3436) 2025-03-04 16:57:25 +01:00
Viktor Liu
acf172b52c Add kernel conntrack counters (#3434) 2025-03-04 16:46:03 +01:00
Viktor Liu
8c81a823fa Add flow ACL IDs (#3421) 2025-03-04 16:43:07 +01:00
Maycon Santos
619c549547 sync port forwarding 2025-03-04 16:29:59 +01:00
hakansa
60ffe0dc87 [client] UI Refactor Icon Paths (#3420)
[client] UI Refactor Icon Paths (#3420)
2025-03-04 18:29:29 +03:00
Maycon Santos
9a713a0987 Merge branch 'feature/port-forwarding' into feature/flow
# Conflicts:
#	go.mod
#	go.sum
2025-03-04 16:28:57 +01:00
Pascal Fischer
c4945cd565 add cleanup scheduler + metrics 2025-03-04 16:21:52 +01:00
Viktor Liu
1e10c17ecb Fix tcp state (#3431) 2025-03-04 11:19:54 +01:00
Viktor Liu
bcc5824980 [client] Close userspace firewall properly (#3426) 2025-03-04 11:19:42 +01:00
robertgro
af5796de1c [client] Add Netbird GitHub link to the client ui about sub menu (#3372) 2025-03-03 17:32:50 +01:00
Philippe Vaucher
9d604b7e66 [client Fix env var typo (#3415) 2025-03-03 17:22:51 +01:00
Viktor Liu
96d5190436 Add icmp type and code to forwarder flow event (#3413) 2025-02-28 21:04:07 +01:00
Viktor Liu
d19c26df06 Fix log direction (#3412) 2025-02-28 21:03:40 +01:00
Viktor Liu
36e36414d9 Fix forwarder log displaying (#3411) 2025-02-28 20:53:01 +01:00
bcmmbaga
7e69589e05 Update management-integrations
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:49:56 +00:00
bcmmbaga
aa613ab79a Update golang.org/x/crypto/ssh
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:27:46 +00:00
Viktor Liu
6ead0ff95e Fix log format 2025-02-28 20:24:23 +01:00
Viktor Liu
0db65a8984 Add routed packet drop flow (#3410) 2025-02-28 20:04:59 +01:00
Pascal Fischer
c138807e95 remove log message 2025-02-28 19:54:50 +01:00
Viktor Liu
637c0c8949 Add icmp type and code (#3409) 2025-02-28 19:16:42 +01:00
Viktor Liu
c72e13d8e6 Add conntrack flows (#3406) 2025-02-28 19:16:29 +01:00
Maycon Santos
f6d7bccfa0 Add flow client with sender/receiver (#3405)
add an initial version of receiver client and flow manager receiver and sender
2025-02-28 17:16:18 +00:00
bcmmbaga
e3ed01cafb go mod tidy
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 17:10:44 +00:00
Viktor Liu
fa748a7ec2 Add userspace flow implementation (#3393) 2025-02-28 11:08:35 +01:00
Maycon Santos
cccc615783 update flow proto package generated code 2025-02-28 03:09:09 +00:00
Maycon Santos
2021463ca0 update flow proto package name 2025-02-28 02:51:57 +00:00
Maycon Santos
f48cfd52e9 fix logger stop (#3403)
* fix logger stop

* use context to stop receiver

* update test
2025-02-28 00:28:17 +00:00
Pascal Fischer
6838f53f40 add getPeerByIp store method 2025-02-27 19:01:05 +01:00
Maycon Santos
8276236dfa Add netflow manager (#3398)
* Add netflow manager

* fix linter issues
2025-02-27 12:05:20 +00:00
Viktor Liu
994b923d56 Move proto and rename port and icmp info (#3399) 2025-02-27 12:52:33 +01:00
Viktor Liu
59e2432231 Add event proto fields (#3397) 2025-02-27 12:29:50 +01:00
Pascal Fischer
eee0d123e4 [management] add flow settings and credentials (#3389) 2025-02-27 12:17:07 +01:00
Viktor Liu
e943203ae2 Add event fields (#3390)
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-02-26 12:06:06 +01:00
Bethuel Mmbaga
82c12cc8ae [management] Handle transaction error on peer deletion (#3387)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 19:57:04 +00:00
Pedro Costa
6a775217cf rename flow proto messages 2025-02-25 16:29:54 +00:00
Maycon Santos
175674749f Add memory flow store (#3386) 2025-02-25 15:23:43 +00:00
Pascal Fischer
1e534cecf6 [management] Add flow proto (#3384) 2025-02-25 13:03:27 +01:00
Pedro Costa
aa3aa8c6a8 [management] flow proto 2025-02-25 11:22:54 +00:00
Pascal Fischer
fbdfe45c25 fix merge conflicts on management 2025-02-25 11:57:25 +01:00
Viktor Liu
81ee172db8 Fix route conflict 2025-02-25 11:44:21 +01:00
Viktor Liu
f8fd65a65f Merge branch 'main' into feature/port-forwarding 2025-02-25 11:37:52 +01:00
Bethuel Mmbaga
62b978c050 [management] Add support for tcp/udp allocations (#3381)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 10:11:50 +00:00
Misha Bragin
266fdcd2ed Replace webinar link (#3380) 2025-02-24 19:12:10 +01:00
Zoltan Papp
0819df916e [client] Replace string to netip.Prefix (#3362)
Replace string to netip.Prefix

---------

Co-authored-by: Hakan Sariman <hknsrmn46@gmail.com>
2025-02-24 15:51:43 +01:00
Pascal Fischer
c8a558f797 [tests] Retry mysql store creation on reused containers (#3370) 2025-02-24 13:40:11 +01:00
hakansa
dabdef4d67 [client] fix extra DNS labels parameter to Register method in client (#3371)
[client] fix extra DNS labels parameter to Register method in client (#3371)
2025-02-24 14:53:59 +03:00
Viktor Liu
cc48594b0b [client][ui] Disable notifications by default (#3375) 2025-02-24 01:14:31 +01:00
Carlos Hernandez
559e673107 [client] fix privacy warning on macOS (#3350)
* fix: macos privacy warning

Move GetDesktopUIUserAgent to its own package so UI does not have to
import client/system package that reaches out to broadcasts address.
Thus, fixing the network privacy warnings.
2025-02-22 12:41:24 +01:00
Pedro Maia Costa
b64bee35fa [management] faster server bootstrap (#3365)
Faster server bootstrap by counting accounts rather than fetching all from storage in the account manager instantiation.

This change moved the deprecated need to ensure accounts have an All group to tests instead.
2025-02-22 11:31:39 +01:00
Viktor Liu
9a0354b681 [client] Update local interface addresses when gathering candidates (#3324) 2025-02-21 19:44:50 +01:00
M. Essam
73101c8977 [client] Restart netbird-ui post-install in linux deb&rpm (#2992) 2025-02-21 19:39:12 +01:00
Viktor Liu
73ce746ba7 [misc] Rename CI client tests (#3366) 2025-02-21 19:07:43 +01:00
Viktor Liu
a74208abac [client] Fix udp forwarder deadline (#3364) 2025-02-21 18:51:52 +01:00
Viktor Liu
b307298b2f [client] Add netbird ui improvements (#3222) 2025-02-21 16:29:21 +01:00
Pedro Maia Costa
f00a997167 [management] fix grpc new account (#3361) 2025-02-21 15:17:42 +01:00
Viktor Liu
5134e3a06a [client] Add reverse dns zone (#3217) 2025-02-21 12:52:04 +01:00
Maycon Santos
6554026a82 [client] fix client/Dockerfile to reduce vulnerabilities (#3359)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE321-MUSL-8720634
- https://snyk.io/vuln/SNYK-ALPINE321-MUSL-8720634
- https://snyk.io/vuln/SNYK-ALPINE321-OPENSSL-8690014
- https://snyk.io/vuln/SNYK-ALPINE321-OPENSSL-8690014
- https://snyk.io/vuln/SNYK-ALPINE321-OPENSSL-8710358

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2025-02-21 12:04:26 +01:00
Christian Stewart
a854660402 [client, signal, management] Update google.golang.org/api to latest (#3288)
* [misc] Add vendor/ to .gitignore

Ignore the vendor/ tree created if someone runs "go mod vendor"

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update google.golang.org/protobuf to latest

Updating protobuf runtime library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update google.golang.org/grpc to latest

Updating grpc library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update golang.org/x/net to latest

Updating x/net library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update golang.org/x/oauth2 to latest

Updating x/oauth2 library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update github.com/stretchr/testify to latest

Updating testify library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update opentelemetry to latest

Updating otel library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [client, signal, management] Update golang.org/x/time to latest

Updating x/time library as a dependency of eventually updating
google.golang.org/api in a future commit.

Signed-off-by: Christian Stewart <christian@aperture.us>

* [management] Update google.golang.org/api to latest

Updating google.golang.org/api library to fix indirect dependency issues with
older versions of OpenTelemetry.

See: #3240

Signed-off-by: Christian Stewart <christian@aperture.us>

---------

Signed-off-by: Christian Stewart <christian@aperture.us>
2025-02-21 12:02:50 +01:00
Misha Bragin
a0b48f971c Add K8s webinar to Readme 2025-02-21 11:13:02 +01:00
Zoltan Papp
96de928cb3 Interface code cleaning (#3358)
Code cleaning in interfaces files
2025-02-21 10:19:38 +01:00
Bethuel Mmbaga
4ebf1410c6 [management] Add support to allocate same port for public and internal (#3347)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-21 11:16:24 +03:00
Pedro Maia Costa
77e40f41f2 [management] refactor auth (#3296) 2025-02-20 20:24:40 +00:00
Viktor Liu
d7d5b1b1d6 Skip CLI session expired notifcation if notifications are disabled (#3266) 2025-02-20 15:01:53 +01:00
Viktor Liu
630edf2480 Remove unused var 2025-02-20 13:24:37 +01:00
Viktor Liu
ea469d28d7 Merge branch 'main' into feature/port-forwarding 2025-02-20 13:24:05 +01:00
Viktor Liu
631ef4ed28 [client] Add embeddable library (#3239) 2025-02-20 13:22:03 +01:00
Pascal Fischer
597f1d47b8 fix management test suite 2025-02-20 13:08:18 +01:00
Viktor Liu
fcc96417f9 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:45:30 +01:00
hakansa
39986b0e97 [client, management] Support DNS Labels for Peer Addressing (#3252)
* [client] Support Extra DNS Labels for Peer Addressing

* [management] Support Extra DNS Labels for Peer Addressing

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-02-20 13:43:20 +03:00
Viktor Liu
8755211a60 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:39:06 +01:00
Viktor Liu
62a0c358f9 [client] Add UI client event notifications (#3207) 2025-02-20 11:00:02 +01:00
César Gonçalves
87311074f1 [misc] improvement(template): add traefik labels to relay (#3333) 2025-02-20 10:56:22 +01:00
Carlos Hernandez
33cf9535b3 [client] Use go build to embed less icons (#3351) 2025-02-20 10:55:44 +01:00
Pascal Fischer
7e6beee7f6 [management] optimize test execution (#3204) 2025-02-19 19:13:45 +01:00
Viktor Liu
27b3891b14 [client] Set up local dns policy additionally if a gpo policy is detected (#3336) 2025-02-19 12:35:30 +01:00
Pascal Fischer
2a864832c6 [management] remove gorm preparestmt from all DB connections (#3292) 2025-02-18 15:24:17 +01:00
Pascal Fischer
c974c12d65 [signal] Fix registry not found (#3342) 2025-02-18 14:23:34 +01:00
hakansa
50926bdbb4 [client] [ui] issue when changing setting in GUI while peer session is expired (#3334)
* [client] [ui] fix issue when changing settings in GUI while peer session is expired
2025-02-18 13:17:34 +03:00
Maycon Santos
bd381d59cd [misc] Run management benchmark jobs on file changes (#3343)
They will always run on Main
2025-02-18 10:45:41 +01:00
Karsa
f67e56d3b9 [client][ui] added accessible tray icons (#3335)
Added accessible tray icons with:
- dark mode support on Windows and Linux, kudos to @burgosz for the PoC
- template icon support on MacOS
Also added appropriate connecting status icons
2025-02-18 02:21:44 +01:00
Bethuel Mmbaga
8fb5a9ce11 [management] add batching support for SaveUsers and SaveGroups (#3341)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-18 00:08:03 +01:00
Bethuel Mmbaga
4cdb2e533a [management] Refactor users to use store methods (#2917)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor posture checks to remove get and save account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor policy get and save account to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve policy groups and posture checks once for validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix typo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add policy tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor anyGroupHasPeers to retrieve all groups once

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor dns settings to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor name server groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add peer store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor ephemeral peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add lock for peer store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor peer handlers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor peer to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix typo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add locks and remove log

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* run peer ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove duplicate store method

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix peer fields updated after save

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use update strength and simplify check

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* prevent changing ruleID when not empty

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* prevent duplicate rules during updates

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor auth middleware

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor account methods and mock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor user and PAT handling

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove db query context and fix get user by id

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix database transaction locking issue

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use UTC time in test

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix prevent users from creating PATs for other users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add missing tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor test names and remove duplicate TestPostgresql_SavePeerStatus

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locks and remove redundant ephemeral check

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups for peers and restrict groups for regular users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix store tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* use account object to get validated peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Improve peer performance

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Get account direct from store without buffer

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add get peer groups tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Adjust benchmarks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Adjust benchmarks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Update benchmark workflow (#3181)

* update local benchmark expectations

* update cloud expectations

* Add status error for generic result error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use integrated validator direct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

* update expectations

* update expectations

* Refactor peer scheduler to retry every 3 seconds on errors

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

* fix validator

* fix validator

* fix validator

* update timeouts

* Refactor ToGroupsInfo to process slices of groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

* update expectations

* update expectations

* Bump integrations version

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor GetValidatedPeers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* go mod tidy

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use peers and groups map for peers validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove mysql from api benchmark tests

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix blocked db calls on user auto groups update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Skip user check for system initiated peer deletion

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context in db calls

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Improve group peer/resource counting (#3192)

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Adjust bench expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Rename GetAccountInfoFromPAT to GetTokenInfo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove global account lock for ListUsers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* build userinfo after updating users in db

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Optimize user bulk deletion  (#3315)

* refactor building user infos

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove unused code

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor GetUsersFromAccount to return a map of UserInfo instead of a slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Export BuildUserInfosForAccount to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fetch account user info once for bulk users save

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update user deletion expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Set max open conns for activity store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update bench expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: Pascal Fischer <pascal@netbird.io>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-02-17 21:43:12 +03:00
Pascal Fischer
abe8da697c [signal] add pprof and message size metrics (#3337) 2025-02-17 17:07:30 +01:00
hakansa
039a985f41 [client] Normalize DNS record names to lowercase in local handler update (#3323)
* [client] Normalize DNS record names to lowercase in lookup
2025-02-14 13:13:40 +03:00
Viktor Liu
c4a6dafd27 [client] Use GPO DNS Policy Config to configure DNS if present (#3319) 2025-02-13 18:17:18 +01:00
Zoltan Papp
a930c2aecf Fix priority handling (#3313) 2025-02-13 15:48:10 +01:00
Pascal Fischer
e6d4653b08 [management] add cloud tag to get ingress ports api spec (#3300)
* fix tag for get endpoint

* update labels
2025-02-12 16:11:54 +01:00
Pedro Maia Costa
d48edb9837 fix integration tests (#3311) 2025-02-12 11:16:51 +00:00
Viktor Liu
b41de7fcd1 [client] Enable userspace forwarder conditionally (#3309)
* Enable userspace forwarder conditionally

* Move disable/enable logic
2025-02-12 11:10:49 +01:00
Viktor Liu
18f84f0df5 [client] Check for fwmark support and use fallback routing if not supported (#3220) 2025-02-11 13:09:17 +01:00
Viktor Liu
44407a158a [client] Fix dns handler chain test (#3307) 2025-02-11 12:42:04 +01:00
Viktor Liu
488b697479 [client] Support dns upstream failover for nameserver groups with same match domain (#3178) 2025-02-10 18:13:34 +01:00
Zoltan Papp
5953b43ead [client, relay] Fix/wg watch (#3261)
Fix WireGuard watcher related issues

- Fix race handling between TURN and Relayed reconnection
- Move the WgWatcher logic to separate struct
- Handle timeouts in a more defensive way
- Fix initial Relay client reconnection to the home server
2025-02-10 10:32:50 +01:00
ransomware
58b2eb4b92 [signal] Fix context propagation in signal server (#3251) 2025-02-07 15:05:41 +01:00
Viktor Liu
05415f72ec [client] Add experimental support for userspace routing (#3134) 2025-02-07 14:11:53 +01:00
Pascal Fischer
b7af53ea40 [management] add logs for grpc API (#3298) 2025-02-07 13:51:17 +01:00
Pascal Fischer
cee4aeea9e [management] Check groups when counting peers on networks list (#3284) 2025-02-06 13:36:57 +01:00
Zoltan Papp
eb69f2de78 Fix nil pointer exception when load empty list and try to cast it (#3282) 2025-02-06 10:28:42 +01:00
Viktor Liu
206420c085 [client] Fix grouping of peer ACLs with different port ranges (#3289) 2025-02-06 10:28:42 +01:00
Christian Stewart
88a864c195 [relay] Use new upstream for nhooyr.io/websocket package (#3287)
The nhooyr.io/websocket package was renamed to github.com/coder/websocket when
the project was transferred to "coder" as the new maintainer.

Use the new import path and update go.mod and go.sum accordingly.

Signed-off-by: Christian Stewart <christian@aperture.us>
2025-02-06 10:28:42 +01:00
Zoltan Papp
ca9aca9b19 Fix nil pointer exception when load empty list and try to cast it (#3282) 2025-02-06 10:20:31 +01:00
Viktor Liu
e00a280329 [client] Fix grouping of peer ACLs with different port ranges (#3289) 2025-02-05 23:04:52 +01:00
Christian Stewart
fe370e7d8f [relay] Use new upstream for nhooyr.io/websocket package (#3287)
The nhooyr.io/websocket package was renamed to github.com/coder/websocket when
the project was transferred to "coder" as the new maintainer.

Use the new import path and update go.mod and go.sum accordingly.

Signed-off-by: Christian Stewart <christian@aperture.us>
2025-02-05 23:03:53 +01:00
Pascal Fischer
a789e9e6d8 [management] fix duplication detection (#3286) 2025-02-05 21:42:09 +01:00
Viktor Liu
9930913e4e Merge branch 'main' into feature/port-forwarding 2025-02-05 18:55:59 +01:00
Viktor Liu
125b5e2b16 [client] Fix acl empty port range detection (#3285) 2025-02-05 18:55:42 +01:00
Viktor Liu
48675f579f Merge branch 'main' into feature/port-forwarding 2025-02-05 17:44:01 +01:00
Pascal Fischer
afec455f86 [management] copy port info (#3283) 2025-02-05 17:30:42 +01:00
Pascal Fischer
035c5d9f23 [management merge only unique entries on network map merge (#3277) 2025-02-05 16:50:45 +01:00
Viktor Liu
97d498c59c [misc, client, management] Replace Wiretrustee with Netbird (#3267) 2025-02-05 16:49:41 +01:00
Viktor Liu
b2a5b29fb2 Merge branch 'main' into feature/port-forwarding 2025-02-05 10:15:37 +01:00
Bethuel Mmbaga
9ec61206c2 [management] Add support for filtering peers by name and IP (#3279)
* add peers ip and name filters

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get peers filter

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix get account peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Extend GetAccountPeers store to support filtering by name and IP

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix get peers references

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-05 00:33:15 +03:00
hakansa
0125cd97d8 [client] use embedded root CA if system certpool is empty (#3272)
* Implement custom TLS certificate handling with fallback to embedded roots
2025-02-04 18:17:59 +03:00
M. Essam
7d385b8dc3 [management] REST client package (#3278) 2025-02-04 10:10:10 +00:00
Zoltan Papp
f930ef2ee6 Cleanup magiconair usage from repo (#3276) 2025-02-03 17:54:35 +01:00
Zoltan Papp
1b011a2d85 [client] Manage the IP forwarding sysctl setting in global way (#3270)
Add new package ipfwdstate that implements reference counting for IP forwarding
state management. This allows multiple usage to safely request IP forwarding
without interfering with each other.
2025-02-03 12:27:18 +01:00
Pascal Fischer
a85ea1ddb0 [manager] ingress ports manager support (#3268)
* add peers manager

* Extend peers manager to support retrieving all peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add network map calc

* move integrations interface

* update management-integrations

* merge main and fix

* go mod tidy

* [management] port forwarding add peer manager fix network map (#3264)

* [management] fix testing tools (#3265)

* Fix net.IPv4 conversion to []byte

* update test to check ipv4

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2025-02-03 09:37:37 +01:00
Zoltán Papp
829e40d2aa Fix ingress manager unnecessary creation 2025-02-01 10:58:47 +01:00
Pascal Fischer
6344e34880 [management] renamed ingress port endpoints (#3263) 2025-02-01 00:40:33 +01:00
Pascal Fischer
a76ca8c565 Merge branch 'main' into feature/port-forwarding 2025-01-29 22:28:10 +01:00
dependabot[bot]
771c99a523 [clien]t Bump golang.org/x/net from 0.30.0 to 0.33.0 (#3218)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.30.0 to 0.33.0.
- [Commits](https://github.com/golang/net/compare/v0.30.0...v0.33.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-01-29 18:25:06 +01:00
Zoltan Papp
26693e4ea8 Feature/port forwarding client ingress (#3242)
Client-side forward handling

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-01-29 16:04:33 +01:00
Viktor Liu
e20be2397c [client] Add missing peer ACL flush (#3247) 2025-01-28 23:25:22 +01:00
Maycon Santos
46766e7e24 [misc] Update sign pipeline version (#3246) 2025-01-28 22:48:19 +01:00
Viktor Liu
a7ddb8f1f8 [client] Replace engine probes with direct calls (#3195) 2025-01-28 12:25:45 +01:00
Pascal Fischer
7335c82553 [management] copy destination and source resource on policyRUle copy (#3235) 2025-01-28 07:05:21 +01:00
Viktor Liu
a32ec97911 [client] Use dynamic dns route resolution on iOS (#3243) 2025-01-27 18:13:10 +01:00
Pascal Fischer
f6a71f4193 [management] add openapi specs and generate types for port forwarding proxy (#3236) 2025-01-27 17:47:40 +01:00
Viktor Liu
5c05131a94 [client] Support port ranges in peer ACLs (#3232) 2025-01-27 13:51:57 +01:00
Pascal Fischer
b6abd4b4da [management/signal/relay] add metrics descriptions (#3233) 2025-01-24 14:17:30 +01:00
Pascal Fischer
2605948e01 [management] use account request buffer on sync (#3229) 2025-01-24 12:04:50 +01:00
Viktor Liu
eb2ac039c7 [client] Mark redirected traffic early to match input filters on pre-DNAT ports (#3205) 2025-01-23 18:00:51 +01:00
Viktor Liu
790a9ed7df [client] Match more specific dns handler first (#3226) 2025-01-23 18:00:05 +01:00
Viktor Liu
2e61ce006d [client] Back up corrupted state files and present them in the debug bundle (#3227) 2025-01-23 17:59:44 +01:00
Viktor Liu
3cc485759e [client] Use correct stdout/stderr log paths for debug bundle on macOS (#3231) 2025-01-23 17:59:22 +01:00
Viktor Liu
aafa9c67fc [client] Fix freebsd default routes (#3230) 2025-01-23 16:57:11 +01:00
Pascal Fischer
69f48db0a3 [management] disable prepareStmt for sqlite (#3228) 2025-01-22 19:53:20 +01:00
Pascal Fischer
8c965434ae [management] remove peer from group on delete (#3223) 2025-01-22 19:33:20 +01:00
Eddie Garcia
78da6b42ad [misc] Fix typo in test output (#3216)
Fix a typo in test output
2025-01-22 18:57:54 +01:00
Bethuel Mmbaga
1ad2cb5582 [management] Refactor peers to use store methods (#2893) 2025-01-20 18:41:46 +01:00
Viktor Liu
c619bf5b0c [client] Allow freebsd to build netbird-ui (#3212) 2025-01-20 11:02:09 +01:00
Maycon Santos
9f4db0a953 [client] Close ice agent only if not nil (#3210) 2025-01-18 00:18:59 +01:00
513 changed files with 39859 additions and 10905 deletions

View File

@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**NetBird status -dA output:**
**Is any other VPN software installed?**
If applicable, add the `netbird status -dA' command output.
If yes, which one?
**Do you face any (non-mobile) client issues?**
**Debug output**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
To help us resolve the problem, please attach the following debug output
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots**
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,4 +1,4 @@
name: Test Code Darwin
name: "Darwin"
on:
push:
@@ -12,9 +12,7 @@ concurrency:
jobs:
test:
strategy:
matrix:
store: ['sqlite']
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Install Go

View File

@@ -1,5 +1,4 @@
name: Test Code FreeBSD
name: "FreeBSD"
on:
push:
@@ -13,6 +12,7 @@ concurrency:
jobs:
test:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
@@ -24,7 +24,7 @@ jobs:
copyback: false
release: "14.1"
prepare: |
pkg install -y go
pkg install -y go pkgconf xorg
# -x - to print all executed commands
# -e - to faile on first error
@@ -33,7 +33,7 @@ jobs:
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...

View File

@@ -1,4 +1,4 @@
name: Test Code Linux
name: Linux
on:
push:
@@ -12,11 +12,21 @@ concurrency:
jobs:
build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -38,7 +48,6 @@ jobs:
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
@@ -89,6 +98,7 @@ jobs:
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
name: "Client / Unit"
needs: [build-cache]
strategy:
fail-fast: false
@@ -134,14 +144,121 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management:
name: "Management / Unit"
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
@@ -194,15 +311,22 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -254,15 +378,22 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -312,16 +443,25 @@ jobs:
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
@@ -363,9 +503,15 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -tags=integration $(go list ./... | grep /management)
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:

View File

@@ -1,4 +1,4 @@
name: Test Code Windows
name: "Windows"
on:
push:
@@ -14,6 +14,7 @@ concurrency:
jobs:
test:
name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code

View File

@@ -1,4 +1,4 @@
name: golangci-lint
name: Lint
on: [pull_request]
permissions:
@@ -27,7 +27,14 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
name: lint
include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:

View File

@@ -1,4 +1,4 @@
name: Mobile build validation
name: Mobile
on:
push:
@@ -12,6 +12,7 @@ concurrency:
jobs:
android_build:
name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
@@ -47,6 +48,7 @@ jobs:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository

View File

@@ -9,10 +9,10 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.17"
SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
COPYRIGHT: "NetBird GmbH"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -71,7 +71,7 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
@@ -150,7 +150,7 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
vendor/

View File

@@ -103,7 +103,7 @@ linters:
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.

View File

@@ -50,10 +50,12 @@ nfpms:
- netbird-ui
formats:
- deb
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/netbird.desktop
- src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-connected.png
- src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -67,10 +69,12 @@ nfpms:
- netbird-ui
formats:
- rpm
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/netbird.desktop
- src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-connected.png
- src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird

View File

@@ -1,3 +1,3 @@
Mikhail Bragin (https://github.com/braginini)
Maycon Santos (https://github.com/mlsmaycon)
Wiretrustee UG (haftungsbeschränkt)
NetBird GmbH

View File

@@ -3,10 +3,10 @@
We are incredibly thankful for the contributions we receive from the community.
We require our external contributors to sign a Contributor License Agreement ("CLA") in
order to ensure that our projects remain licensed under Free and Open Source licenses such
as BSD-3 while allowing Wiretrustee to build a sustainable business.
as BSD-3 while allowing NetBird to build a sustainable business.
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables Wiretrustee to safely commercialize our products
NetBird is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables NetBird to safely commercialize our products
while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project.
@@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
in commercial products.
</li>
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
license to use that patent including within commercial products. You also agree that you
have permission to grant this license.
</li>
@@ -45,7 +45,7 @@ more.
# Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
@@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
require you to sign again.
# Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
You reserve all right, title, and interest in and to Your Contributions.
1. Definitions.
```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
@@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
```
```
"Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
or documentation of, any of the products owned or managed by NetBird (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of,
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution."
```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works.
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
@@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
with Wiretrustee.
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
with NetBird.
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
@@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.

View File

@@ -1,6 +1,6 @@
BSD 3-Clause License
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
Copyright (c) 2022 NetBird GmbH & AUTHORS
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
@@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,4 +1,6 @@
<div align="center">
<br/>
<br/>
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>
</p>
@@ -31,6 +33,10 @@
<br/>
</strong>
<br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
</a>
</p>
<br>

View File

@@ -1,4 +1,4 @@
FROM alpine:3.21.0
FROM alpine:3.21.3
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -9,6 +9,7 @@ USER netbird:netbird
ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -85,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
@@ -196,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
@@ -206,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
@@ -271,13 +272,15 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
)
}
return statusOutputString
}

View File

@@ -0,0 +1,98 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {

View File

@@ -145,6 +145,7 @@ func init() {
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
@@ -153,6 +154,8 @@ func init() {
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)

View File

@@ -2,107 +2,20 @@ package cmd
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
type peerStateDetailOutput struct {
FQDN string `json:"fqdn" yaml:"fqdn"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
}
type peersStateOutput struct {
Total int `json:"total" yaml:"total"`
Connected int `json:"connected" yaml:"connected"`
Details []peerStateDetailOutput `json:"details" yaml:"details"`
}
type signalStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type managementStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutput struct {
Total int `json:"total" yaml:"total"`
Available int `json:"available" yaml:"available"`
Details []relayStateOutputDetail `json:"details" yaml:"details"`
}
type iceCandidateType struct {
Local string `json:"local" yaml:"local"`
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
var (
detailFlag bool
ipv4Flag bool
@@ -173,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
outputInformationHolder := convertToStatusOutputOverview(resp)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
var statusOutputString string
switch {
case detailFlag:
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
case jsonFlag:
statusOutputString, err = parseToJSON(outputInformationHolder)
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
case yamlFlag:
statusOutputString, err = parseToYAML(outputInformationHolder)
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
}
if err != nil {
@@ -214,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
@@ -251,175 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState()
managementOverview := managementStateOutput{
URL: managementState.GetURL(),
Connected: managementState.GetConnected(),
Error: managementState.Error,
}
signalState := pbFullStatus.GetSignalState()
signalOverview := signalStateOutput{
URL: signalState.GetURL(),
Connected: signalState.GetConnected(),
Error: signalState.Error,
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
overview := statusOutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
func mapRelays(relays []*proto.RelayState) relayStateOutput {
var relayStateDetail []relayStateOutputDetail
var relaysAvailable int
for _, relay := range relays {
available := relay.GetAvailable()
relayStateDetail = append(relayStateDetail,
relayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
},
)
if available {
relaysAvailable++
}
}
return relayStateOutput{
Total: len(relays),
Available: relaysAvailable,
Details: relayStateDetail,
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := ""
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
}
if isPeerConnected {
peersConnected++
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
}
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
peerState := peerStateDetailOutput{
IP: pbPeerState.GetIP(),
PubKey: pbPeerState.GetPubKey(),
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
},
IceCandidateEndpoint: iceCandidateType{
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetNetworks(),
Networks: pbPeerState.GetNetworks(),
}
peersStateDetail = append(peersStateDetail, peerState)
}
sortPeersByIP(peersStateDetail)
peersOverview := peersStateOutput{
Total: len(peersStateDetail),
Connected: peersConnected,
Details: peersStateDetail,
}
return peersOverview
}
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
if len(peersStateDetail) > 0 {
sort.SliceStable(peersStateDetail, func(i, j int) bool {
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
return iAddr.Compare(jAddr) == -1
})
}
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {
@@ -427,452 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
}
return fmt.Sprintf("%s\n", ip)
}
func parseToJSON(overview statusOutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func parseToYAML(overview statusOutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
networks,
networks,
peersCountString,
)
return summary
}
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
"%s\n"+
"%s",
parsedPeersString,
summary,
)
}
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""
)
for _, peerState := range peers.Details {
localICE := "-"
if peerState.IceCandidateType.Local != "" {
localICE = peerState.IceCandidateType.Local
}
remoteICE := "-"
if peerState.IceCandidateType.Remote != "" {
remoteICE = peerState.IceCandidateType.Remote
}
localICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Local != "" {
localICEEndpoint = peerState.IceCandidateEndpoint.Local
}
remoteICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "true"
} else {
if rosenpassPermissive {
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
} else {
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
}
}
} else {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
}
}
networks := "-"
if len(peerState.Networks) > 0 {
sort.Strings(peerState.Networks)
networks = strings.Join(peerState.Networks, ", ")
}
peerString := fmt.Sprintf(
"\n %s:\n"+
" NetBird IP: %s\n"+
" Public key: %s\n"+
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
peerState.Status,
peerState.ConnType,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
networks,
networks,
peerState.Latency.String(),
)
peersString += peerString
}
return peersString
}
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := true
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
statusEval = true
}
}
if len(ipsFilter) > 0 {
_, ok := ipsFilterMap[peerState.IP]
if !ok {
ipEval = true
}
}
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval
}
func toIEC(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -1,597 +1,11 @@
package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
func init() {
loc, err := time.LoadLocation("UTC")
if err != nil {
panic(err)
}
time.Local = loc
}
var resp = &proto.StatusResponse{
Status: "Connected",
FullStatus: &proto.FullStatus{
Peers: []*proto.PeerState{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
Fqdn: "peer-1.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
RemoteIceCandidateEndpoint: "",
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Networks: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
Fqdn: "peer-2.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: &proto.SignalState{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: []*proto.RelayState{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
LocalPeerState: &proto.LocalPeerState{
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
}
var overview = statusOutputOverview{
Peers: peersStateOutput{
Total: 2,
Connected: 2,
Details: []peerStateDetailOutput{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
FQDN: "peer-1.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
},
IceCandidateEndpoint: iceCandidateType{
Local: "",
Remote: "",
},
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
FQDN: "peer-2.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
},
IceCandidateEndpoint: iceCandidateType{
Local: "10.0.0.1:10001",
Remote: "10.0.10.1:10002",
},
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
CliVersion: version.NetbirdVersion(),
DaemonVersion: "0.14.1",
ManagementState: managementStateOutput{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: signalStateOutput{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: relayStateOutput{
Total: 2,
Available: 1,
Details: []relayStateOutputDetail{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
},
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
Networks: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := convertToStatusOutputOverview(resp)
assert.Equal(t, overview, convertedResult)
}
func TestSortingOfPeers(t *testing.T) {
peers := []peerStateDetailOutput{
{
IP: "192.168.178.104",
},
{
IP: "192.168.178.102",
},
{
IP: "192.168.178.101",
},
{
IP: "192.168.178.105",
},
{
IP: "192.168.178.103",
},
}
sortPeersByIP(peers)
assert.Equal(t, peers[3].IP, "192.168.178.104")
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := parseToJSON(overview)
//@formatter:off
expectedJSONString := `
{
"peers": {
"total": 2,
"connected": 2,
"details": [
{
"fqdn": "peer-1.awesome-domain.com",
"netbirdIp": "192.168.178.101",
"publicKey": "Pubkey1",
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
"iceCandidateType": {
"local": "",
"remote": ""
},
"iceCandidateEndpoint": {
"local": "",
"remote": ""
},
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
]
},
{
"fqdn": "peer-2.awesome-domain.com",
"netbirdIp": "192.168.178.102",
"publicKey": "Pubkey2",
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
},
"iceCandidateEndpoint": {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null,
"networks": null
}
]
},
"cliVersion": "development",
"daemonVersion": "0.14.1",
"management": {
"url": "my-awesome-management.com:443",
"connected": true,
"error": ""
},
"signal": {
"url": "my-awesome-signal.com:443",
"connected": true,
"error": ""
},
"relays": {
"total": 2,
"available": 1,
"details": [
{
"uri": "stun:my-awesome-stun.com:3478",
"available": true,
"error": ""
},
{
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
"available": false,
"error": "context: deadline exceeded"
}
]
},
"netbirdIp": "192.168.178.100/16",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
}`
// @formatter:on
var expectedJSON bytes.Buffer
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
assert.Equal(t, expectedJSON.String(), jsonString)
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := parseToYAML(overview)
expectedYAML :=
`peers:
total: 2
connected: 2
details:
- fqdn: peer-1.awesome-domain.com
netbirdIp: 192.168.178.101
publicKey: Pubkey1
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
networks: []
cliVersion: development
daemonVersion: 0.14.1
management:
url: my-awesome-management.com:443
connected: true
error: ""
signal:
url: my-awesome-signal.com:443
connected: true
error: ""
relays:
total: 2
available: 1
details:
- uri: stun:my-awesome-stun.com:3478
available: true
error: ""
- uri: turns:my-awesome-turn.com:443?transport=tcp
available: false
error: 'context: deadline exceeded'
netbirdIp: 192.168.178.100/16
publicKey: Some-Pub-Key
usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
Public key: Pubkey1
Status: Connected
-- detail --
Connection type: P2P
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
Public key: Pubkey2
Status: Connected
-- detail --
Connection type: Relayed
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Networks: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`
assert.Equal(t, expectedString, shortVersion)
}
func TestParsingOfIP(t *testing.T) {
InterfaceIP := "192.168.178.123/16"
@@ -599,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -89,13 +90,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}

137
client/cmd/trace.go Normal file
View File

@@ -0,0 +1,137 @@
package cmd
import (
"fmt"
"math/rand"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var traceCmd = &cobra.Command{
Use: "trace <direction> <source-ip> <dest-ip>",
Short: "Trace a packet through the firewall",
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
}
func init() {
debugCmd.AddCommand(traceCmd)
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
traceCmd.Flags().Uint16("sport", 0, "Source port")
traceCmd.Flags().Uint16("dport", 0, "Destination port")
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
}
func tracePacket(cmd *cobra.Command, args []string) error {
direction := strings.ToLower(args[0])
if direction != "in" && direction != "out" {
return fmt.Errorf("invalid direction: use 'in' or 'out'")
}
protocol := cmd.Flag("protocol").Value.String()
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
}
sport, err := cmd.Flags().GetUint16("sport")
if err != nil {
return fmt.Errorf("invalid source port: %v", err)
}
dport, err := cmd.Flags().GetUint16("dport")
if err != nil {
return fmt.Errorf("invalid destination port: %v", err)
}
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
if protocol != "icmp" {
if sport == 0 {
sport = uint16(rand.Intn(16383) + 49152)
}
if dport == 0 {
dport = uint16(rand.Intn(16383) + 49152)
}
}
var tcpFlags *proto.TCPFlags
if protocol == "tcp" {
syn, _ := cmd.Flags().GetBool("syn")
ack, _ := cmd.Flags().GetBool("ack")
fin, _ := cmd.Flags().GetBool("fin")
rst, _ := cmd.Flags().GetBool("rst")
psh, _ := cmd.Flags().GetBool("psh")
urg, _ := cmd.Flags().GetBool("urg")
tcpFlags = &proto.TCPFlags{
Syn: syn,
Ack: ack,
Fin: fin,
Rst: rst,
Psh: psh,
Urg: urg,
}
}
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
SourceIp: args[1],
DestinationIp: args[2],
Protocol: protocol,
SourcePort: uint32(sport),
DestinationPort: uint32(dport),
Direction: direction,
TcpFlags: tcpFlags,
IcmpType: &icmpType,
IcmpCode: &icmpCode,
})
if err != nil {
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
}
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
return nil
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
} else {
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
}
}
disposition := map[bool]string{
true: "\033[32mALLOWED\033[0m", // Green
false: "\033[31mDENIED\033[0m", // Red
}[resp.FinalDisposition]
cmd.Printf("\nFinal disposition: %s\n", disposition)
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util"
)
@@ -29,9 +30,16 @@ const (
interfaceInputType
)
const (
dnsLabelsFlag = "extra-dns-labels"
)
var (
foregroundMode bool
upCmd = &cobra.Command{
foregroundMode bool
dnsLabels []string
dnsLabelsValidated domain.List
upCmd = &cobra.Command{
Use: "up",
Short: "install, login and start Netbird client",
RunE: upFunc,
@@ -49,6 +57,14 @@ func init() {
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -67,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
return err
}
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
if err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
if hostName != "" {
@@ -98,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
@@ -190,7 +212,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -240,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -430,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil
}
func validateDnsLabels(labels []string) (domain.List, error) {
var (
domains domain.List
err error
)
if len(labels) == 0 {
return domains, nil
}
domains, err = domain.ValidateDomains(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
return domains, nil
}
func isValidAddrPort(input string) bool {
if input == "" {
return true

167
client/embed/doc.go Normal file
View File

@@ -0,0 +1,167 @@
// Package embed provides a way to embed the NetBird client directly
// into Go programs without requiring a separate NetBird client installation.
package embed
// Basic Usage:
//
// client, err := embed.New(embed.Options{
// DeviceName: "my-service",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// })
// if err != nil {
// log.Fatal(err)
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// Complete HTTP Server Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "log"
// "net/http"
// "os"
// "os/signal"
// "syscall"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-server",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP server
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
// fmt.Fprintf(w, "Hello from netbird!")
// })
//
// // Listen on netbird network
// l, err := client.ListenTCP(":8080")
// if err != nil {
// log.Fatal(err)
// }
//
// server := &http.Server{Handler: mux}
// go func() {
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
// log.Printf("HTTP server error: %v", err)
// }
// }()
//
// log.Printf("HTTP server listening on netbird network port 8080")
//
// // Handle shutdown
// stop := make(chan os.Signal, 1)
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// <-stop
//
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := server.Shutdown(shutdownCtx); err != nil {
// log.Printf("HTTP shutdown error: %v", err)
// }
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// Complete HTTP Client Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "io"
// "log"
// "os"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-client",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP client that uses netbird network
// httpClient := client.NewHTTPClient()
// httpClient.Timeout = 10 * time.Second
//
// // Make request to server in netbird network
// target := os.Getenv("NB_TARGET")
// resp, err := httpClient.Get(target)
// if err != nil {
// log.Fatal(err)
// }
// defer resp.Body.Close()
//
// // Read and print response
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Printf("Response from server: %s\n", string(body))
//
// // Clean shutdown
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// The package provides several methods for network operations:
// - Dial: Creates outbound connections
// - ListenTCP: Creates TCP listeners
// - ListenUDP: Creates UDP listeners
//
// By default, the embed package uses userspace networking mode, which doesn't
// require root/admin privileges. For production deployments, consider setting
// appropriate config and state paths for persistence.

293
client/embed/embed.go Normal file
View File

@@ -0,0 +1,293 @@
package embed
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
connect *internal.ConnectClient
}
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
// LogLevel sets the logging level (defaults to info if empty)
LogLevel string
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
NoUserspace bool
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
ConfigPath string
// StatePath is the path to the netbird state file
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
}
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
if opts.LogLevel != "" {
level, err := logrus.ParseLevel(opts.LogLevel)
if err != nil {
return nil, fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
}
if !opts.NoUserspace {
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
if opts.StatePath != "" {
// TODO: Disable state if path not provided
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
t := true
var config *internal.Config
var err error
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
config: config,
}, nil
}
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
}
}()
select {
case <-startCtx.Done():
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client
return nil
}
// Stop gracefully stops the client.
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
func (c *Client) Stop(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect == nil {
return ErrClientNotStarted
}
done := make(chan error, 1)
go func() {
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenUDP(udpAddr)
}
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) NewHTTPClient() *http.Client {
transport := &http.Transport{
DialContext: c.Dial,
}
return &http.Client{
Transport: transport,
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
}
addr, err := engine.Address()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
}
return nsnet, addr, nil
}

View File

@@ -10,17 +10,18 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
if err != nil {
return nil, err
}

View File

@@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -33,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager)
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
if !iface.IsUserspaceBind() {
return fm, err
@@ -47,10 +48,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
} else {
fm, errUsp = uspfilter.Create(iface)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
}
if errUsp != nil {

View File

@@ -1,13 +1,18 @@
package firewall
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() device.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
}

View File

@@ -3,7 +3,7 @@ package iptables
import (
"fmt"
"net"
"strconv"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
@@ -30,10 +30,8 @@ type entry struct {
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
@@ -41,12 +39,10 @@ type aclManager struct {
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
routingFwChainName: routingFwChainName,
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
@@ -79,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
}
func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@@ -86,19 +83,19 @@ func (m *aclManager) AddPeerFiltering(
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
@@ -130,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
@@ -138,16 +135,22 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
}
m.updateState()
@@ -190,6 +193,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
m.updateState()
return nil
@@ -302,25 +311,21 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
@@ -354,7 +359,7 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
@@ -371,13 +376,9 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
@@ -387,15 +388,15 @@ func actionToStr(action firewall.Action) string {
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort string) string {
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport"
case sPort != "":
case sPort != nil:
return ipsetName + "-sport"
case dPort != "":
case dPort != nil:
return ipsetName + "-dport"
default:
return ipsetName

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
}
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
_ string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
@@ -166,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
}
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
"all",
nil,
nil,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
@@ -213,6 +214,35 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() iface.WGAddress
AddressFunc func() wgaddr.Address
}
func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
defer func() {
err := manager.Reset(nil)
err := manager.Close(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -72,9 +72,10 @@ func TestIptablesManager(t *testing.T) {
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{8043: 8046},
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -95,18 +96,18 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil)
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
}
})
}
@@ -116,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -135,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
defer func() {
err := manager.Reset(nil)
err := manager.Close(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -145,9 +146,9 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -165,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
})
t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil)
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
})
}
@@ -183,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -203,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
defer func() {
err := manager.Reset(nil)
err := manager.Close(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -214,8 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -15,7 +15,8 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -23,22 +24,36 @@ import (
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
@@ -62,6 +77,7 @@ type router struct {
legacyManagement bool
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
@@ -104,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
}
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -111,7 +129,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
@@ -135,7 +153,16 @@ func (r *router) AddRouteFiltering(
}
rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
} else {
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
}
if err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
@@ -147,12 +174,12 @@ func (r *router) AddRouteFiltering(
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
@@ -203,6 +230,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -229,6 +260,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
@@ -255,7 +290,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
@@ -268,7 +303,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
@@ -296,7 +331,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
@@ -334,9 +369,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
@@ -356,16 +393,22 @@ func (r *router) createContainers() error {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
@@ -406,27 +449,6 @@ func (r *router) addPostroutingRules() error {
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
@@ -445,28 +467,43 @@ func (r *router) addJumpRules() error {
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
return fmt.Errorf("add nat postrouting jump rule: %v", err)
}
r.rules[jumpNat] = natRule
r.rules[jumpNatPost] = natRule
// Jump to prerouting chain
// Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %v", err)
}
r.rules[jumpNatPre] = rdrRule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
var table, chain string
switch ruleKey {
case jumpNatPost:
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table = tableMangle
chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
@@ -511,6 +548,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
r.rules[ruleKey] = rule
r.updateState()
return nil
}
@@ -526,6 +565,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
log.Debugf("marking rule %s not found", ruleKey)
}
r.updateState()
return nil
}
@@ -555,6 +595,137 @@ func (r *router) updateState() {
}
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[string]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
@@ -590,10 +761,10 @@ func applyPort(flag string, port *firewall.Port) []string {
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(p)
portList[i] = strconv.Itoa(int(p))
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(port.Values[0])}
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}

View File

@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
}()
// Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
// 4. jump rule to PRE mangle chain
// 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -239,7 +241,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
@@ -252,7 +254,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
@@ -285,7 +287,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
@@ -297,7 +299,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
@@ -307,8 +309,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
@@ -328,18 +330,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")

View File

@@ -5,12 +5,13 @@ type Rule struct {
ruleID string
ipsetName string
specs []string
ip string
chain string
specs []string
mangleSpecs []string
ip string
chain string
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -4,21 +4,20 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}

View File

@@ -26,8 +26,8 @@ const (
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// GetRuleID returns the rule id
GetRuleID() string
// ID returns the rule id
ID() string
}
// RuleDirection is the traffic direction which a rule is applied
@@ -65,13 +65,13 @@ type Manager interface {
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
id []byte,
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
ipsetName string,
comment string,
) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition
@@ -80,7 +80,15 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
@@ -94,11 +102,23 @@ type Manager interface {
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
// Close closes the firewall manager
Close(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -0,0 +1,27 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,36 +1,37 @@
package manager
import (
"fmt"
"strconv"
)
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []int
Values []uint16
}
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
}
// String interface implementation
@@ -40,7 +41,11 @@ func (p *Port) String() string {
if ports != "" {
ports += ","
}
ports += strconv.Itoa(port)
ports += strconv.Itoa(int(port))
}
if p.IsRange {
ports = "range:" + ports
}
return ports
}

View File

@@ -0,0 +1,19 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -2,9 +2,9 @@ package nftables
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"slices"
"strconv"
"strings"
"time"
@@ -46,6 +46,7 @@ type AclManager struct {
workTable *nftables.Table
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@@ -83,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
@@ -101,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil {
return nil, err
}
@@ -118,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
}
if r.nftSet == nil {
err := m.rConn.DelRule(r.nftRule)
if err != nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
delete(m.rules, r.GetRuleID())
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok {
err := m.rConn.DelRule(r.nftRule)
if err != nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
delete(m.rules, r.GetRuleID())
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil {
@@ -153,16 +163,20 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil
}
err := m.rConn.DelRule(r.nftRule)
if err != nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
err = m.rConn.Flush()
if err != nil {
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err
}
delete(m.rules, r.GetRuleID())
delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
@@ -225,9 +239,12 @@ func (m *AclManager) Flush() error {
return err
}
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
@@ -239,15 +256,15 @@ func (m *AclManager) addIOFiltering(
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
nftRule: r.nftRule,
mangleRule: r.mangleRule,
nftSet: r.nftSet,
ruleID: r.ruleID,
ip: ip,
}, nil
}
@@ -308,68 +325,100 @@ func (m *AclManager) addIOFiltering(
}
}
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, applyPort(dPort, false)...)
if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
mainExpressions := slices.Clone(expressions)
switch action {
case firewall.ActionAccept:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
userData := []byte(ruleId)
chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
Exprs: mainExpressions,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
nftRule: nftRule,
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
return m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
@@ -413,7 +462,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
@@ -421,8 +470,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
@@ -432,43 +479,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
@@ -484,8 +494,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
Kind: expr.VerdictAccept,
},
},
})
@@ -632,6 +641,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
@@ -648,7 +658,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
return
}
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil {
return nil
}
@@ -665,7 +675,11 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])]
if ok {
*r.nftRule = *rule
if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
}
}
@@ -689,12 +703,6 @@ func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, a
return "set:" + ipset.Name + rulesetID
}
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
}
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
@@ -242,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
}
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -318,6 +319,19 @@ func (m *Manager) cleanupNetbirdTables() error {
return nil
}
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
@@ -329,6 +343,22 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() iface.WGAddress
AddressFunc func() wgaddr.Address
}
func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second * 3)
defer func() {
err = manager.Reset(nil)
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept,
},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil)
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
}
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second * 3)
defer func() {
if err := manager.Reset(nil); err != nil {
if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -200,8 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
@@ -283,15 +283,16 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
@@ -307,3 +308,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")
for i := range got {
if _, isCounter := got[i].(*expr.Counter); isCounter {
_, wantIsCounter := want[i].(*expr.Counter)
require.True(t, wantIsCounter, "expected Counter at index %d", i)
continue
}
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
}
}

View File

@@ -14,23 +14,31 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
tableNat = "nat"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
)
const refreshRulesMapError = "refresh rules map: %w"
@@ -49,16 +57,18 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.removeAcceptForwardRules()
var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
}
// Add the single NAT rule that matches on mark
@@ -165,6 +228,7 @@ func (r *router) createContainers() error {
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -173,7 +237,7 @@ func (r *router) AddRouteFiltering(
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
@@ -233,7 +297,13 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey),
}
rule = r.conn.AddRule(rule)
// Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// TODO: Insert after the established rule
rule = r.conn.InsertRule(rule)
} else {
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
@@ -275,7 +345,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
ruleKey := rule.ID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
@@ -404,6 +474,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -830,6 +904,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -890,6 +968,269 @@ func (r *router) refreshRulesMap() error {
return nil
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
@@ -953,15 +1294,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
&expr.Range{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {
@@ -980,7 +1317,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
Data: binaryutil.BigEndian.PutUint16(p),
})
}
}

View File

@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
require.NoError(t, manager.Close(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
require.NoError(t, manager.Close(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
@@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
@@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
@@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
@@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
if string(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
case *expr.Range:
if port.IsRange && len(port.Values) == 2 {
fromPort := binary.BigEndian.Uint16(ex.FromData)
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
portMatchFound = true
}
} else {
}
case *expr.Cmp:
if !port.IsRange {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
if p == portValue {
portMatchFound = true
break
}

View File

@@ -8,13 +8,14 @@ import (
// Rule to handle management of rules
type Rule struct {
nftRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
nftRule *nftables.Rule
mangleRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -3,21 +3,20 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
if err := nft.Close(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}

View File

@@ -3,35 +3,49 @@
package uspfilter
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
return m.nativeFirewall.Close(stateManager)
}
return nil
}

View File

@@ -1,9 +1,12 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -20,26 +23,38 @@ const (
)
// Reset firewall to the default state
func (m *Manager) Reset(*statemanager.Manager) error {
func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() {

View File

@@ -0,0 +1,16 @@
package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -1,21 +1,27 @@
// common.go
package conntrack
import (
"net"
"sync"
"fmt"
"net/netip"
"sync/atomic"
"time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
}
// these small methods will be inlined by the compiler
@@ -25,14 +31,15 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}
// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp
@@ -46,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout
}
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcIP netip.Addr
DstIP netip.Addr
SrcPort uint16
DstPort uint16
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}

View File

@@ -1,114 +1,67 @@
package conntrack
import (
"net"
"context"
"net/netip"
"testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
)
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
}
}
})

View File

@@ -1,11 +1,17 @@
package conntrack
import (
"net"
"context"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -17,154 +23,223 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
SrcIP netip.Addr
DstIP netip.Addr
ID uint16
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
Sequence uint16
ID uint16
ICMPType uint8
ICMPCode uint8
}
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
flowLogger nftypes.FlowLogger
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
tickerCancel: cancel,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
return tracker
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
}
t.mutex.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
if conn.timeoutExceeded(t.timeout) {
return false
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
return key, false
}
func (t *ICMPTracker) cleanupRoutine() {
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
case <-ctx.Done():
return
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.tickerCancel()
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
return ICMPConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
ID: id,
Sequence: seq,
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
}

View File

@@ -1,39 +1,39 @@
package conntrack
import (
"net"
"net/netip"
"testing"
)
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
}
})
}

View File

@@ -3,9 +3,16 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"net"
"context"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -36,6 +43,35 @@ const (
// TCPState represents the state of a TCP connection
type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const (
TCPStateNew TCPState = iota
TCPStateSynSent
@@ -50,90 +86,147 @@ const (
TCPStateClosed
)
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
State TCPState
SourcePort uint16
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
}
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
}
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
done chan struct{}
tickerCancel context.CancelFunc
timeout time.Duration
ipPool *PreallocatedIPs
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
go tracker.cleanupRoutine()
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
return tracker
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
}
t.connections[key] = conn
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.lastSeen.Store(now)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
@@ -142,22 +235,26 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
// Handle RST packets
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
if conn.IsTombstone() {
return true
}
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return false
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
}
conn.Lock()
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
t.updateState(key, conn, flags, false)
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
@@ -166,15 +263,17 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
return
}
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
switch conn.State {
state := conn.State
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
switch state {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
@@ -183,11 +282,11 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished
conn.SetEstablished(true)
} else {
// Simultaneous open
conn.State = TCPStateSynReceived
}
}
@@ -205,28 +304,41 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateCloseWait:
@@ -237,11 +349,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
}
conn.SetTombstone()
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
}
}
@@ -286,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false
}
func (t *TCPTracker) cleanupRoutine() {
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
case <-ctx.Done():
return
}
}
@@ -302,6 +417,12 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
@@ -312,27 +433,26 @@ func (t *TCPTracker) cleanup() {
timeout = TCPHandshakeTimeout
}
lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
if conn.timeoutExceeded(timeout) {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
}
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.tickerCancel()
// Clean up all remaining IPs
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
@@ -350,3 +470,21 @@ func isValidFlagCombination(flags uint8) bool {
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"net"
"net/netip"
"testing"
"time"
@@ -9,11 +9,11 @@ import (
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
},
},
{
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout)
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tt.test(t)
})
}
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
},
wantValid: true,
desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
},
wantValid: false,
desc: "Should reject RST without connection",
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST()
// Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
@@ -220,63 +225,63 @@ func TestRSTHandling(t *testing.T) {
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
}
i++
}
@@ -287,14 +292,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire

View File

@@ -1,9 +1,15 @@
package conntrack
import (
"net"
"context"
"net/netip"
"sync"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -16,96 +22,135 @@ const (
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
}
// UDPTracker manages UDP connection states
type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
flowLogger nftypes.FlowLogger
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
tickerCancel: cancel,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
}
t.mutex.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
return true
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() {
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
case <-ctx.Done():
return
}
}
@@ -117,42 +162,58 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.tickerCancel()
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
if !exists {
return nil, false
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
return conn, true
conn, exists := t.connections[key]
return conn, exists
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,8 @@
package conntrack
import (
"net"
"context"
"net/netip"
"testing"
"time"
@@ -29,55 +30,59 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout)
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done)
assert.NotNil(t, tracker.tickerCancel)
})
}
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Equal(dstIP))
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second)
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct {
name string
srcIP net.IP
dstIP net.IP
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
sleep time.Duration
@@ -94,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
},
{
name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"),
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
@@ -104,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
@@ -144,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got)
})
}
@@ -155,41 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
tickerCancel: tickerCancel,
logger: logger,
flowLogger: flowLogger,
}
// Start cleanup routine
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
// Add some connections
connections := []struct {
srcIP net.IP
dstIP net.IP
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
}
// Verify initial connections
@@ -211,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
}
})
}

View File

@@ -0,0 +1,90 @@
package forwarder
import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *endpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
continue
}
written++
}
return written, nil
}
func (e *endpoint) Wait() {
// not required
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -0,0 +1,169 @@
package forwarder
import (
"context"
"fmt"
"net"
"runtime"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
defaultReceiveWindow = 32768
defaultMaxInFlight = 1024
iosReceiveWindow = 16384
iosMaxInFlight = 256
)
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
},
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %s", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
})
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
maxInFlight := defaultMaxInFlight
if runtime.GOOS == "ios" {
receiveWindow = iosReceiveWindow
maxInFlight = iosMaxInFlight
}
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@@ -0,0 +1,127 @@
package forwarder
import (
"context"
"net"
"net/netip"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
}

View File

@@ -0,0 +1,132 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -0,0 +1,332 @@
package forwarder
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
udpTimeout = 30 * time.Second
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
flowID uuid.UUID
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
type idleConn struct {
id stack.TransportEndpointID
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, mtu)
return &b
},
},
}
go f.cleanup()
return f
}
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
conn.ep.Close()
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
var idleConns []idleConn
f.RLock()
for id, conn := range f.conns {
if conn.getIdleDuration() > udpTimeout {
idleConns = append(idleConns, idleConn{id, conn})
}
}
f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
}
idle.conn.ep.Close()
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
id := r.ID()
f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
return
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,
ep: ep,
flowID: flowID,
}
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}()
errChan := make(chan error, 2)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
for {
if ctx.Err() != nil {
return ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
c.updateLastSeen()
}
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

View File

@@ -0,0 +1,131 @@
package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
}
func newLocalIPManager() *localIPManager {
return &localIPManager{}
}
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
return
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
}
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
}
if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
}
}
interfaces, err := net.Interfaces()
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}
}
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
return nil
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false
}

View File

@@ -0,0 +1,271 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr wgaddr.Address
testIP netip.Addr
expected bool
}{
{
name: "Localhost range",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return tt.setupAddr
},
}
err := manager.UpdateLocalIPs(mock)
require.NoError(t, err)
result := manager.IsLocalIP(tt.testIP)
require.Equal(t, tt.expected, result)
})
}
}
func TestLocalIPManager_AllInterfaces(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{}
// Get actual local interfaces
interfaces, err := net.Interfaces()
require.NoError(t, err)
var tests []struct {
ip string
expected bool
}
// Add all local interface IPs to test cases
for _, iface := range interfaces {
addrs, err := iface.Addrs()
require.NoError(t, err)
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip4 := ip.To4(); ip4 != nil {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip4.String(),
expected: true,
})
}
}
}
// Add some external IPs as negative test cases
externalIPs := []string{
"8.8.8.8",
"1.1.1.1",
"208.67.222.222",
}
for _, ip := range externalIPs {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip,
expected: false,
})
}
require.NotEmpty(t, tests, "No test cases generated")
err = manager.UpdateLocalIPs(mock)
require.NoError(t, err)
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -0,0 +1,252 @@
// Package log provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
)
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
log.Debugf("New uspfilter logger created with loglevel %v", level)
l.wg.Add(1)
go l.worker()
return l
}
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buffer := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
l.handleShutdown(&buffer)
return
case <-ticker.C:
l.flushBuffer(&buffer)
case msg := <-l.msgChannel:
l.processMessage(msg, &buffer)
l.processBatch(&buffer)
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
done := make(chan struct{})
l.closeOnce.Do(func() {
close(l.shutdown)
})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@@ -0,0 +1,121 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -1,27 +1,45 @@
package uspfilter
import (
"net"
"net/netip"
"github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// Rule to handle management of rules
type Rule struct {
// PeerRule to handle management of rules
type PeerRule struct {
id string
ip net.IP
mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
sPort uint16
dPort uint16
sPort *firewall.Port
dPort *firewall.Port
drop bool
comment string
udpHook func([]byte) bool
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
// ID returns the rule id
func (r *PeerRule) ID() string {
return r.id
}
type RouteRule struct {
id string
mgmtId []byte
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// ID returns the rule id
func (r *RouteRule) ID() string {
return r.id
}

View File

@@ -0,0 +1,411 @@
package uspfilter
import (
"fmt"
"net/netip"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
)
type PacketStage int
const (
StageReceived PacketStage = iota
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
}[s]
}
type ForwarderAction struct {
Action string
RemoteAddr string
Error error
}
type TraceResult struct {
Timestamp time.Time
Stage PacketStage
Message string
Allowed bool
ForwarderAction *ForwarderAction
}
type PacketTrace struct {
SourceIP netip.Addr
DestinationIP netip.Addr
Protocol string
SourcePort uint16
DestinationPort uint16
Direction fw.RuleDirection
Results []TraceResult
}
type TCPState struct {
SYN bool
ACK bool
FIN bool
RST bool
PSH bool
URG bool
}
type PacketBuilder struct {
SrcIP netip.Addr
DstIP netip.Addr
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
ICMPType uint8
ICMPCode uint8
Direction fw.RuleDirection
PayloadSize int
TCPState *TCPState
}
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
})
}
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
ForwarderAction: action,
})
}
func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip}
transportLayer, err := p.buildTransportLayer(ip)
if err != nil {
return nil, err
}
pktLayers = append(pktLayers, transportLayer...)
if p.PayloadSize > 0 {
payload := make([]byte, p.PayloadSize)
pktLayers = append(pktLayers, gopacket.Payload(payload))
}
return serializePacket(pktLayers)
}
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}
}
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
switch p.Protocol {
case "tcp":
return p.buildTCPLayer(ip)
case "udp":
return p.buildUDPLayer(ip)
case "icmp":
return p.buildICMPLayer()
default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
}
}
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort),
Window: 65535,
SYN: p.TCPState != nil && p.TCPState.SYN,
ACK: p.TCPState != nil && p.TCPState.ACK,
FIN: p.TCPState != nil && p.TCPState.FIN,
RST: p.TCPState != nil && p.TCPState.RST,
PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG,
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
return []gopacket.SerializableLayer{tcp}, nil
}
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
return []gopacket.SerializableLayer{udp}, nil
}
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
}
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
icmp.Id = uint16(1)
icmp.Seq = uint16(1)
}
return []gopacket.SerializableLayer{icmp}, nil
}
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
return nil, fmt.Errorf("serialize packet: %w", err)
}
return buf.Bytes(), nil
}
func getIPProtocolNumber(protocol fw.Protocol) int {
switch protocol {
case fw.ProtocolTCP:
return int(layers.IPProtocolTCP)
case fw.ProtocolUDP:
return int(layers.IPProtocolUDP)
case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4)
default:
return 0
}
}
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
packetData, err := builder.Build()
if err != nil {
return nil, fmt.Errorf("build packet: %w", err)
}
return m.TracePacket(packetData, builder.Direction), nil
}
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
trace := &PacketTrace{Direction: direction}
// Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace
}
// Extract base packet info
srcIP, dstIP := m.extractIPs(d)
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:
trace.Protocol = "TCP"
trace.SourcePort = uint16(d.tcp.SrcPort)
trace.DestinationPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
trace.Protocol = "UDP"
trace.SourcePort = uint16(d.udp.SrcPort)
trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP"
}
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
if direction == fw.RuleDirectionOUT {
return m.traceOutbound(packetData, trace)
}
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.localipmanager.IsLocalIP(dstIP) {
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter.Load() {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
trace.AddResult(StageConntrack, msg, true)
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
return true
}
trace.AddResult(StageConntrack, msg, false)
return false
}
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
msg := "Matched existing connection state"
switch d.decoded[1] {
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
flags&conntrack.TCPSyn != 0,
flags&conntrack.TCPAck != 0,
flags&conntrack.TCPRst != 0,
flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
}
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true
}
trace.AddResult(StagePeerACL, msg, true)
// Handle netstack mode
if m.netstack {
switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
}
// In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
}
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
return true
}
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
trace.AddResult(StageForwarding, "Forwarding via native router", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {
strId = "<no id>"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
if !allowed {
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
return trace
}
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
fwdAction := &ForwarderAction{
Action: action,
RemoteAddr: remoteAddr,
}
trace.AddResultWithForwarder(StageForwarding,
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
}
return trace
}

View File

@@ -0,0 +1,440 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,12 @@
//go:build uspbench
package uspfilter
import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"strings"
"testing"
@@ -90,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.ActionAccept, "", "allow all")
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -111,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}},
fw.ActionAccept, "", "explicit return")
_, err := m.AddPeerFiltering(
nil,
ip,
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
"",
)
require.NoError(b, err)
}
},
@@ -125,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.ActionDrop, "", "default drop")
_, err := m.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -155,9 +169,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
@@ -179,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
manager.dropFilter(inbound, 0)
}
})
}
@@ -200,9 +214,9 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
@@ -216,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
}
// Test packet
@@ -224,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut)
manager.processOutgoingHooks(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, manager.incomingRules)
manager.dropFilter(testIn, 0)
}
})
}
@@ -248,9 +262,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
@@ -264,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
manager.dropFilter(inbound, 0)
}
})
}
@@ -447,9 +461,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
// Setup scenario
@@ -463,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
manager.processOutgoingHooks(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
manager.dropFilter(inbound, 0)
}
})
}
@@ -574,9 +588,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -587,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -613,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
manager.processOutgoingHooks(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@@ -644,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx])
manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
manager.dropFilter(inPackets[connIdx], 0)
}
})
}
@@ -665,9 +676,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -678,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -753,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
// Data transfer
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
// Connection teardown
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
}
@@ -784,9 +792,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -796,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -821,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
manager.processOutgoingHooks(ack, 0)
}
// Pre-generate test packets
@@ -851,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
})
@@ -872,9 +877,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -883,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
})
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -948,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
})
@@ -996,3 +998,65 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
// Add several route rules to simulate real-world scenario
rules := []struct {
sources []netip.Prefix
dest netip.Prefix
proto fw.Protocol
port *fw.Port
}{
{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
port: &fw.Port{Values: []uint16{80, 443}},
},
{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("10.0.0.0/8"),
},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolICMP,
},
{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.0.0/16"),
proto: fw.ProtocolUDP,
port: &fw.Port{Values: []uint16{53}},
},
}
for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}
}
// Test cases that exercise different matching scenarios
cases := []struct {
srcIP string
dstIP string
proto fw.Protocol
dstPort uint16
}{
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,25 +1,50 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil
}
return i.GetWGDeviceFunc()
}
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
if i.GetDeviceFunc == nil {
return nil
}
return i.GetDeviceFunc()
}
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
@@ -29,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() iface.WGAddress {
func (i *IFaceMock) Address() wgaddr.Address {
if i.AddressFunc == nil {
return iface.WGAddress{}
return wgaddr.Address{}
}
return i.AddressFunc()
}
@@ -41,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -61,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -69,11 +94,10 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -95,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -128,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -139,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string
in bool
expDir fw.RuleDirection
ip net.IP
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
@@ -148,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1),
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
@@ -156,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
@@ -166,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule
var addedRule PeerRule
if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
for _, rule := range manager.incomingRules[tt.ip.String()] {
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
@@ -185,17 +208,17 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
for _, rule := range manager.outgoingRules[tt.ip.String()] {
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
if !tt.ip.Equal(addedRule.ip) {
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
@@ -215,7 +238,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -223,17 +246,16 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
err = m.Reset(nil)
err = m.Close(nil)
if err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
@@ -247,9 +269,18 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -262,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -298,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), m.incomingRules) {
if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
if err = m.Reset(nil); err != nil {
if err = m.Close(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
}
@@ -317,17 +347,17 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface)
manager, err := Create(iface, false, flowLogger)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
defer func() {
require.NoError(t, manager.Reset(nil))
require.NoError(t, manager.Close(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
@@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@@ -371,9 +401,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() {
require.NoError(t, manager.Reset(nil))
require.NoError(t, manager.Close(nil))
}()
manager.decoders = sync.Pool{
@@ -393,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
net.ParseIP("100.10.0.100"),
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
@@ -428,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes())
result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@@ -438,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes())
result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result)
}
@@ -449,12 +479,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
if err := manager.Reset(nil); err != nil {
if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -463,8 +493,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}
@@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
@@ -500,12 +530,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
},
}
defer func() {
require.NoError(t, manager.Reset(nil))
require.NoError(t, manager.Close(nil))
}()
// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
@@ -513,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: layers.IPProtocolUDP,
}
outboundUDP := &layers.UDP{
@@ -539,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes())
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@@ -555,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: dstIP, // Original destination is now source
DstIP: srcIP, // Original source is now destination
SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP.AsSlice(), // Original source is now destination
Protocol: layers.IPProtocolUDP,
}
inboundUDP := &layers.UDP{
@@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -655,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
@@ -14,6 +13,8 @@ import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type RecvMessage struct {
@@ -52,9 +53,10 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
delete(b.endpoints, fakeIP)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net"
"slices"
"strings"
"sync"
@@ -152,46 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
@@ -203,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
mux.updateLocalAddresses()
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
m.mu.Lock()
m.localAddrsForUnspecified = localAddrsForUnspecified
m.mu.Unlock()
}
// LocalAddr returns the listening address of this UDPMuxDefault
@@ -214,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
defer m.mu.Unlock()
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
return slices.Clone(m.localAddrsForUnspecified)
}
return []net.Addr{m.LocalAddr()}
@@ -225,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
m.mu.Unlock()
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String())
}

View File

@@ -17,6 +17,8 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
// embed UDPMux
@@ -118,6 +122,7 @@ type udpConn struct {
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {

View File

@@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,

View File

@@ -2,5 +2,5 @@
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
// WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "wt0"

View File

@@ -2,5 +2,5 @@
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
// WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "utun100"

View File

@@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
AllowedIPs: allowedIps,
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
@@ -362,7 +356,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -3,16 +3,23 @@
package iface
import (
"golang.zx2c4.com/wireguard/tun/netstack"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
}

View File

@@ -9,14 +9,16 @@ import (
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -30,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -63,7 +65,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
@@ -92,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement
return nil
}
@@ -122,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *WGTunDevice) WgAddress() WGAddress {
func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address
}
@@ -130,6 +132,10 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -9,14 +9,16 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -28,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -84,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -105,7 +107,7 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
@@ -117,6 +119,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
@@ -138,3 +145,7 @@ func (t *TunDevice) assignAddr() error {
}
return nil
}
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -2,6 +2,7 @@ package device
import (
"net"
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
@@ -10,16 +11,16 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool
DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool
DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) {
if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -10,14 +10,16 @@ import (
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
iceBind *bind.ICEBind
@@ -29,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -64,7 +66,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
@@ -119,11 +121,11 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement
return nil
}
@@ -131,3 +133,7 @@ func (t *TunDevice) UpdateAddr(addr WGAddress) error {
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -9,15 +9,18 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address WGAddress
address wgaddr.Address
wgPort int
key string
mtu int
@@ -32,9 +35,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
return closErr
}
func (t *TunKernelDevice) WgAddress() WGAddress {
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address
}
@@ -153,6 +155,11 @@ func (t *TunKernelDevice) DeviceName() string {
return t.name
}
// Device returns the wireguard device, not applicable for kernel devices
func (t *TunKernelDevice) Device() *device.Device {
return nil
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil
}
@@ -161,3 +168,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -8,15 +8,18 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunNetstackDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -25,12 +28,14 @@ type TunNetstackDevice struct {
device *device.Device
filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun
nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -43,13 +48,19 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
}
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
t.net = net
t.device = device.NewDevice(
t.filteredDevice,
@@ -87,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
return nil
}
@@ -106,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
return nil
}
func (t *TunNetstackDevice) WgAddress() WGAddress {
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address
}
@@ -117,3 +128,12 @@ func (t *TunNetstackDevice) DeviceName() string {
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// Device returns the wireguard device
func (t *TunNetstackDevice) Device() *device.Device {
return t.device
}
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}

View File

@@ -4,20 +4,20 @@ package device
import (
"fmt"
"os"
"runtime"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type USPDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -29,11 +29,9 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
return &USPDevice{
name: name,
address: address,
@@ -96,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address WGAddress) error {
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -116,7 +114,7 @@ func (t *USPDevice) Close() error {
return nil
}
func (t *USPDevice) WgAddress() WGAddress {
func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address
}
@@ -128,6 +126,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// Device returns the wireguard device
func (t *USPDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
@@ -135,11 +138,6 @@ func (t *USPDevice) assignAddr() error {
return link.assignAddr(t.address)
}
func checkUser() {
if runtime.GOOS == "freebsd" {
euid := os.Geteuid()
if euid != 0 {
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
}
}
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -8,17 +8,19 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -31,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -117,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -138,7 +140,7 @@ func (t *TunDevice) Close() error {
}
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
@@ -150,6 +152,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet")
@@ -169,3 +176,7 @@ func (t *TunDevice) assignAddr() error {
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
}
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -11,7 +11,7 @@ import (
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)

View File

@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {

View File

@@ -1,16 +1,23 @@
package iface
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
}

View File

@@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
return fmt.Errorf("set interface addr: %w", err)
}
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
return nil
}

View File

@@ -3,18 +3,23 @@ package iface
import (
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -24,8 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -68,7 +71,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
func (w *WGIface) Address() device.WGAddress {
func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress()
}
@@ -99,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr)
addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil {
return err
}
@@ -109,12 +112,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@@ -203,6 +207,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
return w.tun.FilteredDevice()
}
// GetWGDevice returns the WireGuard device
func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey)
@@ -234,3 +243,22 @@ func (w *WGIface) waitUntilRemoved() error {
}
}
}
// GetNet returns the netstack.Net for the netstack device
func (w *WGIface) GetNet() *netstack.Net {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.GetNet()
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -1,112 +0,0 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}

Some files were not shown because too many files have changed in this diff Show More