Compare commits

..

2 Commits

Author SHA1 Message Date
Maycon Santos
5f41e2bd13 Merge branch 'main' into debug-google-workspace 2024-02-19 15:26:58 +01:00
bcmmbaga
e3d038da8a debug google workspace request 2024-02-16 13:10:37 +03:00
261 changed files with 3657 additions and 19791 deletions

View File

@@ -2,7 +2,7 @@
name: Bug/Issue report
about: Create a report to help us improve
title: ''
labels: ['triage-needed']
labels: ['triage']
assignees: ''
---

View File

@@ -32,14 +32,8 @@ jobs:
restore-keys: |
macos-go-
- name: Install libpcap
run: brew install libpcap
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...

View File

@@ -14,8 +14,8 @@ jobs:
test:
strategy:
matrix:
arch: [ '386','amd64' ]
store: [ 'jsonfile', 'sqlite' ]
arch: ['386','amd64']
store: ['jsonfile', 'sqlite']
runs-on: ubuntu-latest
steps:
- name: Install Go
@@ -36,20 +36,13 @@ jobs:
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI' -timeout 5m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
@@ -71,14 +64,11 @@ jobs:
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
@@ -86,7 +76,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@@ -113,7 +103,7 @@ jobs:
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -46,7 +46,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,
ignore_words_list: erro,clienta
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -33,10 +33,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:
@@ -44,7 +40,7 @@ jobs:
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:

View File

@@ -11,7 +11,7 @@ concurrency:
cancel-in-progress: true
jobs:
android_build:
andrloid_build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
@@ -38,10 +38,10 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android netbird lib
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
@@ -56,10 +56,10 @@ jobs:
with:
go-version: "1.21.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
- name: build iOS nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -190,9 +190,6 @@ jobs:
-
name: Install modules
run: go mod tidy
-
name: check git status
run: git --no-pager diff --exit-code
-
name: Run GoReleaser
id: goreleaser

View File

@@ -127,9 +127,6 @@ jobs:
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
@@ -162,13 +159,6 @@ jobs:
test $count -eq 4
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
test-getting-started-script:
runs-on: ubuntu-latest
steps:
@@ -196,16 +186,3 @@ jobs:
run: test -f zitadel.env
- name: test dashboard.env file gen
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

2
.gitignore vendored
View File

@@ -29,4 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
GeoLite2-City*
*.db

View File

@@ -63,14 +63,6 @@ linters-settings:
enable:
- nilness
revive:
rules:
- name: exported
severity: warning
disabled: false
arguments:
- "checkPrivateReceivers"
- "sayRepetitiveInsteadOfStutters"
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
@@ -101,7 +93,6 @@ linters:
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements

View File

@@ -54,7 +54,7 @@ nfpms:
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-connected.png
- src: client/ui/netbird-systemtray-default.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -71,7 +71,7 @@ nfpms:
contents:
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-connected.png
- src: client/ui/netbird-systemtray-default.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird

View File

@@ -1,6 +1,6 @@
<p align="center">
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
<strong>:hatching_chick: New Release! Self-hosting in under 5 min.</strong>
<a href="https://github.com/netbirdio/netbird#quickstart-with-self-hosted-netbird">
Learn more
</a>
</p>
@@ -40,26 +40,27 @@
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open-Source Network Security in a Single Platform
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### Secure peer-to-peer VPN with SSO and MFA in minutes
https://user-images.githubusercontent.com/700848/197345890-2e2cded5-7b7a-436f-a444-94e80dd24f46.mov
### Key features
| Connectivity | Management | Security | Automation | Platforms |
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> |
| Connectivity | Management | Automation | Platforms |
|---------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|---------------------------------------|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li> - \[x] Peer-to-peer encryption </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | | <ul><li> - \[x] Docker </ul></li> |
| <ul><li> - \[x] Post-quantum-secure connection through [Rosenpass](https://rosenpass.eu) </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | | |
| | <ul><li> - \[x] SSH access management </ul></li> | | |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
@@ -78,7 +79,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
@@ -95,9 +96,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
@@ -108,8 +109,8 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
### Community projects
- [NetBird on OpenWRT](https://github.com/messense/openwrt-netbird)
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
@@ -121,7 +122,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -1,5 +1,5 @@
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird

View File

@@ -79,7 +79,6 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
var ctx context.Context
//nolint
@@ -110,7 +109,6 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
var ctx context.Context
//nolint

View File

@@ -1,212 +0,0 @@
package anonymize
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"regexp"
"slices"
"strings"
)
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
currentAnonIPv4 netip.Addr
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
return &Anonymizer{
ipAnonymizer: map[netip.Addr]netip.Addr{},
domainAnonymizer: map[string]string{},
currentAnonIPv4: startIPv4,
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
}
}
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
if ip.IsLoopback() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
a.isInAnonymizedRange(ip) {
return ip
}
if _, ok := a.ipAnonymizer[ip]; !ok {
if ip.Is4() {
a.ipAnonymizer[ip] = a.currentAnonIPv4
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
} else {
a.ipAnonymizer[ip] = a.currentAnonIPv6
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
}
}
return a.ipAnonymizer[ip]
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
return true
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
return true
}
return false
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
}
return a.AnonymizeIP(addr).String()
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return uri
}
var anonymizedHost string
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
u.Host = anonymizedHost
}
return u.String()
}
func (a *Anonymizer) AnonymizeString(str string) string {
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
for domain, anonDomain := range a.domainAnonymizer {
str = strings.ReplaceAll(str, domain, anonDomain)
}
str = a.AnonymizeSchemeURI(str)
str = a.AnonymizeDNSLogLine(str)
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
}
if slices.Contains(wellKnown, addr.String()) {
return true
}
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
return cgnatRange.Contains(addr)
}
func generateRandomString(length int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
continue
}
result[i] = letters[num.Int64()]
}
return string(result)
}

View File

@@ -1,223 +0,0 @@
package anonymize_test
import (
"net/netip"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
)
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
name string
ip string
expect string
}{
{"Well known", "8.8.8.8", "8.8.8.8"},
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ip := netip.MustParseAddr(tc.ip)
anonymizedIP := anonymizer.AnonymizeIP(ip)
if anonymizedIP.String() != tc.expect {
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
}
})
}
}
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
domain string
expectPattern string
shouldAnonymize bool
}{
{
"General Domain",
"example.com",
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDomain(tc.domain)
if tc.shouldAnonymize {
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
} else {
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
}
})
}
}
func TestAnonymizeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
uri string
regex string
}{
{
"HTTP URI with Port",
"http://example.com:80/path",
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
},
{
"HTTP URI without Port",
"http://example.com/path",
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
},
{
"Opaque URI with Port",
"stun:example.com:80?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
},
{
"Opaque URI without Port",
"stun:example.com?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeURI(tc.uri)
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizeSchemeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeSchemeURI(tc.input)
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizString_MemorizedDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "This is a test string including the domain example.com which should be anonymized."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
}
func TestAnonymizeString_DoubleURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "Check out our site at https://example.com for more info."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
}
func TestAnonymizeString_IPAddresses(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
tests := []struct {
name string
input string
expect string
}{
{
name: "IPv4 Address",
input: "Error occurred at IP 122.138.1.1",
expect: "Error occurred at IP 198.51.100.0",
},
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeString(tc.input)
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
})
}
}

View File

@@ -1,248 +0,0 @@
package cmd
import (
"context"
"fmt"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
}
var debugBundleCmd = &cobra.Command{
Use: "bundle",
Example: " netbird debug bundle",
Short: "Create a debug bundle",
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
RunE: debugBundle,
}
var logCmd = &cobra.Command{
Use: "log",
Short: "Manage logging for the Netbird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
}
var logLevelCmd = &cobra.Command{
Use: "level <level>",
Short: "Set the logging level for this session",
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
Available log levels are:
panic: for panic level, highest level of severity
fatal: for fatal level errors that cause the program to exit
error: for error conditions
warn: for warning conditions
info: for informational messages
debug: for debug-level messages
trace: for trace-level messages, which include more fine-grained information than debug`,
Args: cobra.ExactArgs(1),
RunE: setLogLevel,
}
var forCmd = &cobra.Command{
Use: "for <time>",
Short: "Run debug logs for a specified duration and create a debug bundle",
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
Example: " netbird debug for 5m",
Args: cobra.ExactArgs(1),
RunE: runForDuration,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := parseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: level,
})
if err != nil {
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level set successfully to", args[0])
return nil
}
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second)
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
startTime := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
elapsed := time.Since(startTime)
if elapsed >= duration {
return
}
remaining := duration - elapsed
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}

View File

@@ -25,16 +25,12 @@ import (
)
const (
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
rosenpassPermissiveFlag = "rosenpass-permissive"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
externalIPMapFlag = "external-ip-map"
dnsResolverAddress = "dns-resolver-address"
enableRosenpassFlag = "enable-rosenpass"
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
)
var (
@@ -58,14 +54,8 @@ var (
natExternalIPs []string
customDNSAddress string
rosenpassEnabled bool
rosenpassPermissive bool
serverSSHAllowed bool
interfaceName string
wireguardPort uint16
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
@@ -104,24 +94,15 @@ func init() {
if runtime.GOOS == "windows" {
defaultDaemonAddr = "tcp://127.0.0.1:41731"
}
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
@@ -129,20 +110,8 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -157,9 +126,6 @@ func init() {
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
}
// SetupCloseHandler handles SIGTERM signal and exits with success
@@ -210,7 +176,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
return prefix + upper
}
// DialClientGRPCServer returns client connection to the daemon server.
// DialClientGRPCServer returns client connection to the dameno server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
defer cancel()
@@ -350,14 +316,3 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true
}
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}

View File

@@ -1,131 +0,0 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -2,6 +2,8 @@ package cmd
import (
"context"
"runtime"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -22,8 +24,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
}
func newSVCConfig() *service.Config {
name := "netbird"
if runtime.GOOS == "windows" {
name = "Netbird"
}
return &service.Config{
Name: serviceName,
Name: name,
DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue),

View File

@@ -64,10 +64,6 @@ var installCmd = &cobra.Command{
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
@@ -81,7 +77,6 @@ var installCmd = &cobra.Command{
cmd.PrintErrln(err)
return err
}
cmd.Println("Netbird service has been installed")
return nil
},
@@ -111,7 +106,7 @@ var uninstallCmd = &cobra.Command{
if err != nil {
return err
}
cmd.Println("Netbird service has been uninstalled")
cmd.Println("Netbird has been uninstalled")
return nil
},
}

View File

@@ -24,7 +24,7 @@ var (
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Use: "ssh",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
"You can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}

View File

@@ -6,8 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
@@ -16,7 +14,6 @@ import (
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -37,9 +34,6 @@ type peerStateDetailOutput struct {
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
}
type peersStateOutput struct {
@@ -77,28 +71,17 @@ type iceCandidateType struct {
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
}
var (
@@ -147,9 +130,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed initializing log %v", err)
}
ctx := internal.CtxInitState(cmd.Context())
ctx := internal.CtxInitState(context.Background())
resp, err := getStatus(ctx)
resp, err := getStatus(ctx, cmd)
if err != nil {
return err
}
@@ -182,7 +165,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = parseToYAML(outputInformationHolder)
default:
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false)
}
if err != nil {
@@ -194,7 +177,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -203,7 +186,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -270,25 +253,16 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
overview := statusOutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
}
return overview
@@ -320,19 +294,6 @@ func mapRelays(relays []*proto.RelayState) relayStateOutput {
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
localICE := ""
@@ -385,9 +346,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetRoutes(),
}
peersStateDetail = append(peersStateDetail, peerState)
@@ -437,7 +395,8 @@ func parseToYAML(overview statusOutputOverview) (string, error) {
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
@@ -473,7 +432,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
interfaceIP = "N/A"
}
var relaysString string
var relayAvailableString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
@@ -482,98 +441,42 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
relayAvailableString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
routes := "-"
if len(overview.Routes) > 0 {
sort.Strings(overview.Routes)
routes = strings.Join(overview.Routes, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
relayAvailableString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
relayAvailableString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
routes,
peersCountString,
)
return summary
}
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
parsedPeersString := parsePeers(overview.Peers)
summary := parseGeneralSummary(overview, true, true)
return fmt.Sprintf(
"Peers detail:"+
@@ -584,7 +487,7 @@ func parseToFullDetailSummary(overview statusOutputOverview) string {
)
}
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
func parsePeers(peers peersStateOutput) string {
var (
peersString = ""
)
@@ -610,28 +513,14 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "true"
} else {
if rosenpassPermissive {
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
} else {
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
}
}
} else {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
}
lastStatusUpdate := "-"
if !peerState.LastStatusUpdate.IsZero() {
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
}
routes := "-"
if len(peerState.Routes) > 0 {
sort.Strings(peerState.Routes)
routes = strings.Join(peerState.Routes, ", ")
lastWireguardHandshake := "-"
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
lastWireguardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
}
peerString := fmt.Sprintf(
@@ -645,11 +534,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Latency: %s\n",
" Last Wireguard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
@@ -660,13 +546,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
lastStatusUpdate,
lastWireguardHandshake,
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
routes,
peerState.Latency.String(),
)
peersString += peerString
@@ -720,139 +603,3 @@ func toIEC(b int64) string {
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -3,14 +3,11 @@ package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
@@ -45,10 +42,6 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Routes: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
@@ -65,7 +58,6 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
@@ -95,31 +87,6 @@ var resp = &proto.StatusResponse{
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Routes: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
@@ -149,10 +116,6 @@ var overview = statusOutputOverview{
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
@@ -173,7 +136,6 @@ var overview = statusOutputOverview{
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
@@ -209,31 +171,6 @@ var overview = statusOutputOverview{
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -294,12 +231,7 @@ func TestParsingToJSON(t *testing.T) {
},
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
]
"transferSent": 100
},
{
"fqdn": "peer-2.awesome-domain.com",
@@ -319,10 +251,7 @@ func TestParsingToJSON(t *testing.T) {
},
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null
"transferSent": 1000
}
]
},
@@ -357,34 +286,7 @@ func TestParsingToJSON(t *testing.T) {
"netbirdIp": "192.168.178.100/16",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
"fqdn": "some-localhost.awesome-domain.com"
}`
// @formatter:on
@@ -418,10 +320,6 @@ func TestParsingToYAML(t *testing.T) {
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
@@ -438,9 +336,6 @@ func TestParsingToYAML(t *testing.T) {
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
cliVersion: development
daemonVersion: 0.14.1
management:
@@ -465,39 +360,15 @@ netbirdIp: 192.168.178.100/16
publicKey: Some-Pub-Key
usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
expectedDetail :=
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
@@ -508,12 +379,9 @@ func TestParsingToDetail(t *testing.T) {
Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Last connection update: %s
Last WireGuard handshake: %s
Last connection update: 2001-01-01 01:01:01
Last Wireguard handshake: 2001-01-01 01:01:02
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
@@ -524,50 +392,38 @@ func TestParsingToDetail(t *testing.T) {
Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Last connection update: %s
Last WireGuard handshake: %s
Last connection update: 2002-02-02 02:02:02
Last Wireguard handshake: 2002-02-02 02:02:03
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
CLI version: development
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
`
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
shortVersion := parseGeneralSummary(overview, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
expectedString :=
`Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`
@@ -581,31 +437,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -13,7 +13,6 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto"
@@ -79,8 +78,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -40,7 +40,6 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -84,26 +83,17 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
}
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
@@ -120,18 +110,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@@ -151,6 +129,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
@@ -191,7 +170,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -202,18 +180,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
loginRequest.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err

View File

@@ -26,7 +26,7 @@ type HTTPClient interface {
}
// AuthFlowInfo holds information for the OAuth 2.0 authorization flow
type AuthFlowInfo struct { //nolint:revive
type AuthFlowInfo struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`

View File

@@ -30,26 +30,20 @@ const (
DefaultAdminURL = "https://app.netbird.io:443"
)
var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
AdminURL string
ConfigPath string
PreSharedKey *string
ServerSSHAllowed *bool
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
DisableAutoConnect *bool
ExtraIFaceBlackList []string
ManagementURL string
AdminURL string
ConfigPath string
PreSharedKey *string
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
InterfaceName *string
WireguardPort *int
}
// Config Configuration type
@@ -64,8 +58,6 @@ type Config struct {
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
// SSHKey is a private SSH key in a PEM format
SSHKey string
@@ -87,10 +79,6 @@ type Config struct {
NATExternalIPs []string
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
CustomDNSAddress string
// DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility
DisableAutoConnect bool
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -100,7 +88,6 @@ func ReadConfig(configPath string) (*Config, error) {
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
}
return config, nil
}
@@ -165,8 +152,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
ServerSSHAllowed: util.False(),
DisableAutoConnect: false,
}
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
@@ -201,14 +186,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config.RosenpassEnabled = *input.RosenpassEnabled
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
}
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return nil, err
@@ -223,8 +200,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config.AdminURL = newURL
}
// nolint:gocritic
config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...)
config.IFaceBlackList = defaultInterfaceBlacklist
return config, nil
}
@@ -304,33 +280,6 @@ func update(input ConfigInput) (*Config, error) {
refresh = true
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
refresh = true
}
if input.DisableAutoConnect != nil {
config.DisableAutoConnect = *input.DisableAutoConnect
refresh = true
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
refresh = true
}
if config.ServerSSHAllowed == nil {
config.ServerSSHAllowed = util.True()
refresh = true
}
if len(input.ExtraIFaceBlackList) > 0 {
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
refresh = true
}
}
if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil {
@@ -395,6 +344,7 @@ func configFileIsExists(path string) bool {
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version.
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err

View File

@@ -18,6 +18,7 @@ func TestGetConfig(t *testing.T) {
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
if err != nil {
return
}
@@ -85,26 +86,6 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
}
func TestExtraIFaceBlackList(t *testing.T) {
extraIFaceBlackList := []string{"eth1"}
path := filepath.Join(t.TempDir(), "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: path,
ExtraIFaceBlackList: extraIFaceBlackList,
})
if err != nil {
return
}
assert.Contains(t, config.IFaceBlackList, "eth1")
readConf, err := util.ReadJson(path, config)
if err != nil {
return
}
assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1")
}
func TestHiddenPreSharedKey(t *testing.T) {
hidden := "**********"
samplePreSharedKey := "mysecretpresharedkey"
@@ -130,6 +111,7 @@ func TestHiddenPreSharedKey(t *testing.T) {
ConfigPath: cfgFile,
PreSharedKey: tt.preSharedKey,
})
if err != nil {
t.Fatalf("failed to get cfg: %s", err)
}

View File

@@ -4,8 +4,6 @@ import (
"context"
"errors"
"fmt"
"runtime"
"runtime/debug"
"strings"
"time"
@@ -25,13 +23,12 @@ import (
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
}
// RunClientWithProbes runs the client's main logic with probes attached
@@ -43,9 +40,8 @@ func RunClientWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
}
// RunClientMobile with main logic on mobile system
@@ -67,7 +63,7 @@ func RunClientMobile(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
}
func RunClientiOS(
@@ -83,7 +79,7 @@ func RunClientiOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
}
func runClient(
@@ -95,15 +91,8 @@ func runClient(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
@@ -245,9 +234,6 @@ func runClient(
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
if engineChan != nil {
engineChan <- engine
}
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
@@ -257,10 +243,6 @@ func runClient(
backOff.Reset()
if engineChan != nil {
engineChan <- nil
}
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
@@ -301,8 +283,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
}
if config.PreSharedKey != "" {

View File

@@ -8,7 +8,6 @@ import (
"net/netip"
"os"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
@@ -24,16 +23,10 @@ const (
fileMaxNumberOfSearchDomains = 6
)
const (
dnsFailoverTimeout = 4 * time.Second
dnsFailoverAttempts = 1
)
type fileConfigurator struct {
repair *repair
originalPerms os.FileMode
nbNameserverIP string
originalPerms os.FileMode
}
func newFileConfigurator() (hostManager, error) {
@@ -71,7 +64,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
nbSearchDomains := searchDomains(config)
f.nbNameserverIP = config.ServerIP
nbNameserverIP := config.ServerIP
resolvConf, err := parseBackupResolvConf()
if err != nil {
@@ -80,11 +73,11 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
err = f.updateConfig(nbSearchDomains, nbNameserverIP, resolvConf)
if err != nil {
return err
}
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
f.repair.watchFileChanges(nbSearchDomains, nbNameserverIP)
return nil
}
@@ -92,11 +85,10 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg)
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent(
searchDomainList,
nameServers,
options)
cfg.others)
log.Debugf("creating managed file %s", defaultResolvConfPath)
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
@@ -139,12 +131,7 @@ func (f *fileConfigurator) backup() error {
}
func (f *fileConfigurator) restore() error {
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
@@ -170,7 +157,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[1], err)
return restoreResolvConfFile()
}

View File

@@ -5,7 +5,6 @@ package dns
import (
"fmt"
"os"
"regexp"
"strings"
log "github.com/sirupsen/logrus"
@@ -15,9 +14,6 @@ const (
defaultResolvConfPath = "/etc/resolv.conf"
)
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
type resolvConf struct {
nameServers []string
searchDomains []string
@@ -107,62 +103,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
}
return rconf, nil
}
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
// otherwise it adds a new option with timeout and attempts.
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
configs := make([]string, len(input))
copy(configs, input)
for i, config := range configs {
if strings.HasPrefix(config, "options") {
config = strings.ReplaceAll(config, "rotate", "")
config = strings.Join(strings.Fields(config), " ")
if strings.Contains(config, "timeout:") {
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
}
if strings.Contains(config, "attempts:") {
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
}
configs[i] = config
return configs
}
}
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
}
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename, nameserverIP string) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@@ -6,8 +6,6 @@ import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_parseResolvConf(t *testing.T) {
@@ -174,131 +172,3 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg)
}
}
func TestPrepareOptionsWithTimeout(t *testing.T) {
tests := []struct {
name string
others []string
timeout int
attempts int
expected []string
}{
{
name: "Append new options with timeout and attempts",
others: []string{"some config"},
timeout: 2,
attempts: 2,
expected: []string{"some config", "options timeout:2 attempts:2"},
},
{
name: "Modify existing options to exclude rotate and include timeout and attempts",
others: []string{"some config", "options rotate someother"},
timeout: 3,
attempts: 2,
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
},
{
name: "Existing options with timeout and attempts are updated",
others: []string{"some config", "options timeout:4 attempts:3"},
timeout: 5,
attempts: 4,
expected: []string{"some config", "options timeout:5 attempts:4"},
},
{
name: "Modify existing options, add missing attempts before timeout",
others: []string{"some config", "options timeout:4"},
timeout: 4,
attempts: 3,
expected: []string{"some config", "options attempts:3 timeout:4"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
assert.Equal(t, tc.expected, result)
})
}
}
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@@ -65,7 +65,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return nil, err
}
log.Infof("System DNS manager discovered: %s", osManager)
log.Debugf("discovered mode is: %s", osManager)
return newHostManagerFromType(wgInterface, osManager)
}

View File

@@ -31,8 +31,6 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
response := d.lookupRecord(r)
if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)

View File

@@ -53,12 +53,10 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
searchDomainList := searchDomains(config)
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.ServerIP}, r.originalNameServers...),
options)
r.othersConfigs)
// create a backup for unclean shutdown detection before the resolv.conf is changed
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net/netip"
"strings"
"sync"
"github.com/miekg/dns"
@@ -12,7 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -61,8 +59,6 @@ type DefaultServer struct {
// make sense on mobile only
searchDomainNotifier *notifier
iosDnsManager IosDnsManager
statusRecorder *peer.Status
}
type handlerWithStop interface {
@@ -77,12 +73,7 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -99,20 +90,13 @@ func NewDefaultServer(
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
return newDefaultServer(ctx, wgInterface, dnsService), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(
ctx context.Context,
wgInterface WGIface,
hostsDnsList []string,
config nbdns.Config,
listener listener.NetworkChangeListener,
statusRecorder *peer.Status,
) *DefaultServer {
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
@@ -124,18 +108,13 @@ func NewDefaultServerPermanentUpstream(
}
// NewDefaultServerIos returns a new dns server. It optimized for ios
func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
@@ -145,8 +124,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
}
return defaultServer
@@ -278,15 +256,9 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)
go func(mux handlerWithStop) {
defer wg.Done()
mux.probeAvailability()
}(mux)
mux.probeAvailability()
}
wg.Wait()
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
@@ -327,8 +299,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
}
s.updateNSGroupStates(update.NameServerGroups)
return nil
}
@@ -368,13 +338,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue
}
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
)
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
}
@@ -496,14 +460,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
func (s *DefaultServer) upstreamCallbacks(
nsGroup *nbdns.NameServerGroup,
handler dns.Handler,
) (deactivate func(error), reactivate func()) {
) (deactivate func(), reactivate func()) {
var removeIndex map[string]int
deactivate = func(err error) {
deactivate = func() {
s.mux.Lock()
defer s.mux.Unlock()
l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("Temporarily deactivating nameservers group due to timeout")
l.Info("temporary deactivate nameservers group due timeout")
removeIndex = make(map[string]int)
for _, domain := range nsGroup.Domains {
@@ -522,11 +486,8 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
}
s.updateNSState(nsGroup, err, false)
}
reactivate = func() {
s.mux.Lock()
@@ -549,20 +510,12 @@ func (s *DefaultServer) upstreamCallbacks(
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
s.updateNSState(nsGroup, nil, true)
}
return
}
func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
)
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
@@ -582,50 +535,7 @@ func (s *DefaultServer) addHostRootZone() {
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
}
handler.deactivate = func(error) {}
handler.deactivate = func() {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
}
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState
for _, group := range groups {
var servers []string
for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
}
state := peer.NSGroupState{
ID: generateGroupKey(group),
Servers: servers,
Domains: group.Domains,
// The probe will determine the state, default enabled
Enabled: true,
Error: nil,
}
states = append(states, state)
}
s.statusRecorder.UpdateDNSStates(states)
}
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
states := s.statusRecorder.GetDNSStates()
id := generateGroupKey(nsGroup)
for i, state := range states {
if state.ID == id {
states[i].Enabled = enabled
states[i].Error = err
break
}
}
s.statusRecorder.UpdateDNSStates(states)
}
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string
for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
}
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
}

View File

@@ -15,7 +15,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
@@ -275,7 +274,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Fatal(err)
}
@@ -376,7 +375,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -471,7 +470,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
if err != nil {
t.Fatalf("%v", err)
}
@@ -542,7 +541,6 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
{false, "domain2", false},
},
},
statusRecorder: &peer.Status{},
}
var domainsUpdate string
@@ -565,7 +563,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
},
}, nil)
deactivate(nil)
deactivate()
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
@@ -603,7 +601,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
var dnsList []string
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -627,7 +625,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -719,7 +717,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -750,11 +748,6 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("9.9.9.9"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,

View File

@@ -11,11 +11,8 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
@@ -48,13 +45,12 @@ type upstreamResolverBase struct {
reactivatePeriod time.Duration
upstreamTimeout time.Duration
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
deactivate func()
reactivate func()
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
ctx, cancel := context.WithCancel(parentCTX)
return &upstreamResolverBase{
ctx: ctx,
@@ -62,7 +58,6 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder,
}
}
@@ -73,10 +68,7 @@ func (u *upstreamResolverBase) stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
var err error
defer func() {
u.checkUpstreamFails(err)
}()
defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
@@ -89,6 +81,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
for _, upstream := range u.upstreamServers {
var rm *dns.Msg
var t time.Duration
var err error
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
@@ -139,7 +132,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
func (u *upstreamResolverBase) checkUpstreamFails() {
u.mutex.Lock()
defer u.mutex.Unlock()
@@ -153,7 +146,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
default:
}
u.disable(err)
u.disable()
}
// probeAvailability tests all upstream servers simultaneously and
@@ -172,16 +165,13 @@ func (u *upstreamResolverBase) probeAvailability() {
var mu sync.Mutex
var wg sync.WaitGroup
var errors *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream
wg.Add(1)
go func() {
defer wg.Done()
err := u.testNameserver(upstream)
if err != nil {
errors = multierror.Append(errors, err)
if err := u.testNameserver(upstream); err != nil {
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
@@ -196,7 +186,7 @@ func (u *upstreamResolverBase) probeAvailability() {
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errors.ErrorOrNil())
u.disable()
}
}
@@ -255,15 +245,15 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
func (u *upstreamResolverBase) disable() {
if u.disabled {
return
}
// todo test the deactivation logic, it seems to affect the client
if runtime.GOOS != "ios" {
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate(err)
log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true
go u.waitUntilResponse()
}

View File

@@ -11,8 +11,6 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolverIOS struct {
@@ -22,14 +20,8 @@ type upstreamResolverIOS struct {
iIndex int
}
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip net.IP,
net *net.IPNet,
statusRecorder *peer.Status,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
index, err := getInterfaceIndex(interfaceName)
if err != nil {

View File

@@ -8,22 +8,14 @@ import (
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolverNonIOS struct {
*upstreamResolverBase
}
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
nonIOS := &upstreamResolverNonIOS{
upstreamResolverBase: upstreamResolverBase,
}

View File

@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil)
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
@@ -131,7 +131,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
}
failed := false
resolver.deactivate = func(error) {
resolver.deactivate = func() {
failed = true
}

View File

@@ -79,10 +79,7 @@ type EngineConfig struct {
CustomDNSAddress string
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
RosenpassEnabled bool
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -93,10 +90,6 @@ type Engine struct {
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
beforePeerHook peer.BeforeAddPeerHookFunc
afterPeerHook peer.AfterRemovePeerHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@@ -111,9 +104,6 @@ type Engine struct {
// TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route
cancel context.CancelFunc
ctx context.Context
@@ -219,8 +209,6 @@ func (e *Engine) Stop() error {
return err
}
e.clientRoutes = nil
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond)
@@ -239,51 +227,38 @@ func (e *Engine) Start() error {
wgIface, err := e.newWgIface()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
return fmt.Errorf("new wg interface: %w", err)
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error())
return err
}
e.wgInterface = wgIface
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {
log.Infof("running rosenpass in permissive mode")
} else {
log.Infof("running rosenpass in strict mode")
}
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err)
return err
}
err := e.rpManager.Run()
if err != nil {
return fmt.Errorf("run rosenpass manager: %w", err)
return err
}
}
initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
return err
}
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
}
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
return err
}
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
@@ -295,7 +270,7 @@ func (e *Engine) Start() error {
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
return err
}
}
@@ -303,7 +278,7 @@ func (e *Engine) Start() error {
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err)
return err
}
if e.firewall != nil {
@@ -313,7 +288,7 @@ func (e *Engine) Start() error {
err = e.dnsServer.Initialize()
if err != nil {
e.close()
return fmt.Errorf("initialize dns server: %w", err)
return err
}
e.receiveSignalEvents()
@@ -507,52 +482,44 @@ func isNil(server nbssh.Server) bool {
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if !e.config.ServerSSHAllowed {
log.Warnf("running SSH server is not permitted")
return nil
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on Windows is not supported")
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
if err != nil {
return err
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on Windows is not supported")
return nil
}
return nil
// start SSH server if it wasn't running
if isNil(e.sshServer) {
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
if err != nil {
return err
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -700,14 +667,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
log.Errorf("failed to update routes, err: %v", err)
}
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
@@ -718,16 +682,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
e.dnsServer.ProbeAvailability()
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
e.networkSerial = serial
return nil
}
@@ -802,7 +765,6 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)
@@ -826,15 +788,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
if _, ok := e.peerConns[peerKey]; !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
return err
}
e.peerConns[peerKey] = conn
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
@@ -903,7 +860,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
PreSharedKey: e.config.PreSharedKey,
}
if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
if e.config.RosenpassEnabled {
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
rk := []byte(wgConfig.RemoteKey)
var keyInput []byte
@@ -1128,10 +1085,6 @@ func (e *Engine) close() {
e.dnsServer.Stop()
}
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
@@ -1146,6 +1099,10 @@ func (e *Engine) close() {
}
}
if e.routeManager != nil {
e.routeManager.Stop()
}
if e.firewall != nil {
err := e.firewall.Reset()
if err != nil {
@@ -1215,21 +1172,14 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx,
e.wgInterface,
e.mobileDep.HostDNSAddresses,
*dnsConfig,
e.mobileDep.NetworkChangeListener,
e.statusRecorder,
)
dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener)
go e.mobileDep.DnsReadyListener.OnReady()
return routes, dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager)
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil {
return nil, nil, err
}
@@ -1237,28 +1187,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
}
}
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
return e.clientRoutes
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 {
id = id[:i]
}
routes[id] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
@@ -1342,7 +1270,7 @@ func (e *Engine) receiveProbeEvents() {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
}
// wgStats could be zero value, in which case we just reset the stats
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
if err := e.statusRecorder.UpdateWireguardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}

View File

@@ -21,8 +21,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -72,11 +70,10 @@ func TestEngine_SSH(t *testing.T) {
defer cancel()
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
@@ -578,10 +575,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr
return testCase.inputErr
},
}
@@ -598,8 +595,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
})
}
}
@@ -743,8 +740,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
return nil, nil, nil
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
@@ -1052,8 +1049,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
eventStore, false)
if err != nil {
return nil, "", err
}

View File

@@ -1 +0,0 @@
package internal

View File

@@ -20,15 +20,12 @@ import (
"github.com/netbirdio/netbird/iface/bind"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
defaultWgKeepAlive = 25 * time.Second
)
@@ -101,9 +98,6 @@ type IceCredentials struct {
Pwd string
}
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
type Conn struct {
config ConnConfig
mu sync.Mutex
@@ -139,13 +133,6 @@ type Conn struct {
adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover
sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
}
// meta holds meta information about a connection
@@ -206,22 +193,20 @@ func (conn *Conn) reCreateAgent() error {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
}
if conn.config.DisableIPv6Discovery {
@@ -229,6 +214,7 @@ func (conn *Conn) reCreateAgent() error {
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil {
return err
}
@@ -248,17 +234,6 @@ func (conn *Conn) reCreateAgent() error {
return err
}
err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency())
if err != nil {
log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err)
return
}
})
if err != nil {
return fmt.Errorf("failed setting binding response callback: %w", err)
}
return nil
}
@@ -284,7 +259,6 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -344,7 +318,6 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -375,9 +348,6 @@ func (conn *Conn) Open() error {
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
@@ -402,14 +372,6 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
conn.mu.Lock()
@@ -435,15 +397,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.connID = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
log.Errorf("Before add peer hook failed: %v", err)
}
}
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
@@ -454,10 +407,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
conn.status = StatusConnected
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
peerState := State{
PubKey: conn.config.Key,
@@ -466,10 +415,8 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
@@ -536,15 +483,6 @@ func (conn *Conn) cleanup() error {
// todo: is it problem if we try to remove a peer what is never existed?
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.connID != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connID); err != nil {
log.Errorf("After remove peer hook failed: %v", err)
}
}
}
conn.connID = ""
if conn.notifyDisconnected != nil {
conn.notifyDisconnected()
conn.notifyDisconnected = nil
@@ -560,7 +498,6 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -568,7 +505,7 @@ func (conn *Conn) cleanup() error {
// todo rethink status updates
log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err)
}
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
if err := conn.statusRecorder.UpdateWireguardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
log.Debugf("failed to reset wireguard stats for peer %s: %s", conn.config.Key, err)
}

View File

@@ -10,10 +10,9 @@ import (
)
const (
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
)
func iceKeepAlive() time.Duration {
@@ -22,7 +21,7 @@ func iceKeepAlive() time.Duration {
return iceKeepAliveDefault
}
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
@@ -38,7 +37,7 @@ func iceDisconnectedTimeout() time.Duration {
return iceDisconnectedTimeoutDefault
}
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
@@ -48,22 +47,6 @@ func iceDisconnectedTimeout() time.Duration {
return time.Duration(disconnectedTimeoutSec) * time.Second
}
func iceRelayAcceptanceMinWait() time.Duration {
iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec)
if iceRelayAcceptanceMinWaitEnv == "" {
return iceRelayAcceptanceMinWaitDefault
}
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
return iceRelayAcceptanceMinWaitDefault
}
return time.Duration(disconnectedTimeoutSec) * time.Second
}
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"

View File

@@ -5,16 +5,12 @@ import (
"sync"
"time"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
)
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
@@ -29,40 +25,6 @@ type State struct {
LastWireguardHandshake time.Time
BytesTx int64
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -71,7 +33,6 @@ type LocalPeerState struct {
PubKey string
KernelInterface bool
FQDN string
Routes map[string]struct{}
}
// SignalState contains the latest state of a signal connection
@@ -88,51 +49,30 @@ type ManagementState struct {
Error error
}
// RosenpassState contains the latest state of the Rosenpass configuration
type RosenpassState struct {
Enabled bool
Permissive bool
}
// NSGroupState represents the status of a DNS server group, including associated domains,
// whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct {
ID string
Servers []string
Domains []string
Enabled bool
Error error
}
// FullStatus contains the full state held by the Status instance
type FullStatus struct {
Peers []State
ManagementState ManagementState
SignalState SignalState
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
}
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -175,7 +115,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
@@ -222,10 +161,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@@ -237,7 +172,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
}
d.peers[receivedState.PubKey] = peerState
@@ -256,8 +190,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
// UpdateWireguardPeerState updates the wireguard bits of the peer state
func (d *Status) UpdateWireguardPeerState(pubKey string, wgStats iface.WGStats) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -330,13 +264,6 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
return ch
}
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer
}
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock()
@@ -389,14 +316,6 @@ func (d *Status) UpdateManagementAddress(mgmAddress string) {
d.mgmAddress = mgmAddress
}
// UpdateRosenpass update the Rosenpass configuration
func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
d.mux.Lock()
defer d.mux.Unlock()
d.rosenpassPermissive = rosenpassPermissive
d.rosenpassEnabled = rosenpassEnabled
}
// MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock()
@@ -423,19 +342,6 @@ func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
d.relayStates = relayResults
}
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.mux.Lock()
defer d.mux.Unlock()
d.nsGroupStates = dnsStates
}
func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
}
}
func (d *Status) GetManagementState() ManagementState {
return ManagementState{
d.mgmAddress,
@@ -444,39 +350,6 @@ func (d *Status) GetManagementState() ManagementState {
}
}
func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
if latency <= 0 {
return nil
}
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[pubKey]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.Latency = latency
d.peers[pubKey] = peerState
return nil
}
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.Lock()
defer d.mux.Unlock()
// if peer is connected to the management then login is not expired
if d.managementState {
return false
}
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true
}
return false
}
func (d *Status) GetSignalState() SignalState {
return SignalState{
d.signalAddress,
@@ -489,10 +362,6 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
@@ -503,8 +372,6 @@ func (d *Status) GetFullStatus() FullStatus {
SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
for _, status := range d.peers {

View File

@@ -3,7 +3,6 @@ package peer
import (
"errors"
"testing"
"sync"
"github.com/stretchr/testify/assert"
)
@@ -43,7 +42,6 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -64,7 +62,6 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -83,7 +80,6 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -108,7 +104,6 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -10,9 +10,6 @@ import (
"github.com/pion/stun/v2"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net"
)
// ProbeResult holds the info about the result of a relay probe request
@@ -30,15 +27,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
}()
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
}
client, err := stun.DialURI(uri, &stun.DialConfig{
Net: net,
})
client, err := stun.DialURI(uri, &stun.DialConfig{})
if err != nil {
probeErr = fmt.Errorf("dial: %w", err)
return
@@ -96,13 +85,14 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
switch uri.Proto {
case stun.ProtoTypeUDP:
var err error
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
conn, err = net.ListenPacket("udp", "")
if err != nil {
probeErr = fmt.Errorf("listen: %w", err)
return
}
case stun.ProtoTypeTCP:
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
if err != nil {
probeErr = fmt.Errorf("dial: %w", err)
return
@@ -119,18 +109,12 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
}()
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
}
cfg := &turn.ClientConfig{
STUNServerAddr: turnServerAddr,
TURNServerAddr: turnServerAddr,
Conn: conn,
Username: uri.Username,
Password: uri.Password,
Net: net,
}
client, err := turn.NewClient(cfg)
if err != nil {

View File

@@ -3,9 +3,7 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
@@ -20,7 +18,6 @@ type routerPeerStatus struct {
connected bool
relayed bool
direct bool
latency time.Duration
}
type routesUpdate struct {
@@ -44,7 +41,6 @@ type clientNetwork struct {
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
ctx: ctx,
stop: cancel,
@@ -71,29 +67,14 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
direct: peerStatus.Direct,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
}
// getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections.
//
// It follows these prioritization rules:
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
chosen := ""
chosenScore := float64(0)
currScore := float64(0)
chosenScore := 0
currID := ""
if c.chosenRoute != nil {
@@ -101,7 +82,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
}
for _, r := range c.routes {
tempScore := float64(0)
tempScore := 0
peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected {
continue
@@ -109,18 +90,9 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = float64(metricDiff) * 10
tempScore = metricDiff * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
}
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
tempScore++
}
@@ -129,7 +101,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
chosen = r.ID
chosenScore = tempScore
}
@@ -138,30 +110,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
chosen = r.ID
chosenScore = tempScore
}
if r.ID == currID {
currScore = tempScore
}
}
switch {
case chosen == "":
if chosen == "" {
var peers []string
for _, r := range c.routes {
peers = append(peers, r.Peer)
}
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
case chosen != currID:
if currScore != 0 && currScore < chosenScore+0.1 {
return currID
} else {
var peer string
if route := c.routes[chosen]; route != nil {
peer = route.Peer
}
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, peer, chosenScore, c.network)
}
} else if chosen != currID {
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
}
return chosen
@@ -198,21 +158,15 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Errorf("get peer state: %v", err)
return err
}
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil
@@ -220,26 +174,30 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
if err != nil {
return err
}
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route: %v", err)
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
if err != nil {
return fmt.Errorf("couldn't remove route %s from system, err: %v",
c.network, err)
}
}
return nil
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
var err error
routerPeerStatuses := c.getRouterPeerStatuses()
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if chosen == "" {
if err := c.removeRouteFromPeerAndSystem(); err != nil {
return fmt.Errorf("remove route from peer and system: %v", err)
err = c.removeRouteFromPeerAndSystem()
if err != nil {
return err
}
c.chosenRoute = nil
@@ -247,7 +205,6 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return nil
}
// If the chosen route is the same as the current route, do nothing
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil
@@ -255,31 +212,21 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
if c.chosenRoute != nil {
// If a previous route exists, remove it from the peer
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route from peer: %v", err)
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
if err != nil {
return err
}
} else {
// otherwise add the route to the system
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
if err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
}
}
c.chosenRoute = c.routes[chosen]
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
@@ -320,21 +267,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
log.Debugf("stopping watcher for network %s", c.network)
err := c.removeRouteFromPeerAndSystem()
if err != nil {
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
log.Error(err)
}
return
case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
log.Error(err)
}
case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number, ignoring it")
log.Warnf("received a routes update with smaller serial number, ignoring it")
continue
}
log.Debugf("Received a new client network route update for %s", c.network)
log.Debugf("received a new client network route update for %s", c.network)
c.handleUpdate(update)
@@ -342,22 +289,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
log.Error(err)
}
c.startPeersStatusChangeWatcher()
}
}
}
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
}
return intf
}

View File

@@ -3,7 +3,6 @@ package routemanager
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/route"
)
@@ -14,7 +13,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
name string
statuses map[string]routerPeerStatus
expectedRouteID string
currentRoute string
currentRoute *route.Route
existingRoutes map[string]*route.Route
}{
{
@@ -33,7 +32,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
@@ -52,7 +51,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
@@ -71,7 +70,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
@@ -90,7 +89,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "",
},
{
@@ -119,7 +118,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
@@ -148,7 +147,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
@@ -177,141 +176,18 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "multiple connected peers with different latencies",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 300 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 12 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork
client := &clientNetwork{
network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes,
chosenRoute: currentRoute,
chosenRoute: tc.currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -2,10 +2,6 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"runtime"
"sync"
@@ -14,24 +10,14 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
// nolint:unused
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelection(map[string][]*route.Route)
GetRouteSelector() *routeselector.RouteSelector
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
@@ -44,7 +30,6 @@ type DefaultManager struct {
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
@@ -58,7 +43,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
@@ -72,31 +56,9 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
return dm
}
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
mgmtAddress := m.statusRecorder.GetManagementState().URL
signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}
log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
if err != nil {
return err
}
@@ -109,42 +71,32 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
}
m.ctx = nil
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return nil, nil, m.ctx.Err()
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return nil, nil, fmt.Errorf("update routes: %w", err)
return err
}
}
return newServerRoutesMap, newClientRoutesIDMap, nil
return nil
}
}
@@ -158,51 +110,16 @@ func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.initialRouteRanges()
}
// GetRouteSelector returns the route selector
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
return m.clientNetworks
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks)
m.stopObsoleteClients(networks)
for id, routes := range networks {
if _, found := m.clientNetworks[id]; found {
// don't touch existing client network watchers
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks)
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
@@ -219,7 +136,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
}
}
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
@@ -240,7 +157,11 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
if !isPrefixSupported(newRoute.Network) {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < minRangeBits {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
@@ -251,47 +172,10 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes)
_, crMap := m.classifiesRoutes(initialRoutes)
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}
func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin", "ios":
return true
}
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management
if prefix.Bits() <= minRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), prefix)
return false
}
return true
}
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
func resolveURLsToIPs(urls []string) []net.IP {
var ips []net.IP
for _, rawurl := range urls {
u, err := url.Parse(rawurl)
if err != nil {
log.Errorf("Failed to parse url %s: %v", rawurl, err)
continue
}
ipAddrs, err := net.LookupIP(u.Hostname())
if err != nil {
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
continue
}
ips = append(ips, ipAddrs...)
}
return ips
}

View File

@@ -28,14 +28,13 @@ const remotePeerKey2 = "remote1"
func TestManagerUpdateRoutes(t *testing.T) {
testCases := []struct {
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
clientNetworkWatchersExpectedAllowed int
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
}{
{
name: "Should create 2 client networks",
@@ -201,9 +200,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedAllowed: 1,
inputSerial: 1,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove 1 Client Route",
@@ -417,10 +415,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
_, _, err = routeManager.Init()
require.NoError(t, err, "should init route manager")
defer routeManager.Stop()
if testCase.removeSrvRouter {
@@ -428,18 +422,14 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
}
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
sr := routeManager.serverRouter.(*defaultServerRouter)

View File

@@ -6,22 +6,14 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelectionFunc func(map[string][]*route.Route)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func()
}
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
StopFunc func()
}
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
@@ -30,25 +22,11 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
}
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)
}
}
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
if m.GetRouteSelectorFunc != nil {
return m.GetRouteSelectorFunc()
}
return nil
return fmt.Errorf("method UpdateRoutes is not implemented")
}
// Start mock implementation of Start from Manager interface

View File

@@ -1,127 +0,0 @@
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
type ref struct {
count int
nexthop netip.Addr
intf *net.Interface
}
type RouteManager struct {
// refCountMap keeps track of the reference ref for prefixes
refCountMap map[netip.Prefix]ref
// prefixMap keeps track of the prefixes associated with a connection ID for removal
prefixMap map[nbnet.ConnectionID][]netip.Prefix
addRoute AddRouteFunc
removeRoute RemoveRouteFunc
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap
return &RouteManager{
refCountMap: map[netip.Prefix]ref{},
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
addRoute: addRoute,
removeRoute: removeRoute,
}
}
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
ref := rm.refCountMap[prefix]
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
// Add route to the system, only if it's a new prefix
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, ErrRouteNotFound) {
return nil
}
if errors.Is(err, ErrRouteNotAllowed) {
log.Debugf("Adding route for prefix %s: %s", prefix, err)
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}
ref.nexthop = nexthop
ref.intf = intf
}
ref.count++
rm.refCountMap[prefix] = ref
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
return nil
}
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
prefixes, ok := rm.prefixMap[connID]
if !ok {
log.Debugf("No prefixes found for connection ID %s", connID)
return nil
}
var result *multierror.Error
for _, prefix := range prefixes {
ref := rm.refCountMap[prefix]
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
if ref.count == 1 {
log.Debugf("Removing route for prefix %s", prefix)
// TODO: don't fail if the route is not found
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
continue
}
delete(rm.refCountMap, prefix)
} else {
ref.count--
rm.refCountMap[prefix] = ref
}
}
delete(rm.prefixMap, connID)
return result.ErrorOrNil()
}
// Flush removes all references and routes from the system
func (rm *RouteManager) Flush() error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
var result *multierror.Error
for prefix := range rm.refCountMap {
log.Debugf("Removing route for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]ref{}
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
return result.ErrorOrNil()
}

View File

@@ -7,10 +7,9 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -4,34 +4,30 @@ package routemanager
import (
"context"
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type defaultServerRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
statusRecorder *peer.Status
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: firewall,
wgInterface: wgInterface,
statusRecorder: statusRecorder,
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: firewall,
wgInterface: wgInterface,
}, nil
}
@@ -49,7 +45,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.routes, routeID)
@@ -63,7 +59,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.routes[id] = newRoute
@@ -82,28 +78,16 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("Not removing from server network because context is done")
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
return err
}
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
@@ -111,31 +95,16 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("Not adding to server network because context is done")
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
return err
}
err = m.firewall.InsertRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[route.Network.String()] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
@@ -144,33 +113,19 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err)
continue
log.Warnf("failed to remove clean up route: %s", r.ID)
}
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
}
state := m.statusRecorder.GetLocalPeerState()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
parsed, err := netip.ParsePrefix(source)
if err != nil {
return firewall.RouterPair{}, err
}
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
parsed := netip.MustParsePrefix(source).Masked()
return firewall.RouterPair{
ID: route.ID,
Source: parsed.String(),
Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade,
}, nil
}
}

View File

@@ -1,414 +0,0 @@
//go:build !android && !ios
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRouteNotFound = errors.New("route not found")
var ErrRouteNotAllowed = errors.New("route not allowed")
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := getNextHop(addr)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Warnf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
}
return addr.Unmap(), intf, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
}
return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@@ -1,33 +1,13 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@@ -1,4 +1,5 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package routemanager
@@ -8,7 +9,6 @@ import (
"net/netip"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
)
@@ -52,24 +52,16 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
continue
}
if len(m.Addrs) < 3 {
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
continue
}
addr, ok := toNetIPAddr(m.Addrs[0])
if !ok {
continue
}
cidr := 32
if mask := m.Addrs[2]; mask != nil {
cidr, ok = toCIDR(mask)
if !ok {
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
continue
}
mask, ok := toNetIPMASK(m.Addrs[2])
if !ok {
continue
}
cidr, _ := mask.Size()
routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() {
@@ -82,19 +74,20 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
return netip.AddrFrom4(t.IP), true
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
addr := netip.MustParseAddr(ip.String())
return addr, true
default:
return netip.Addr{}, false
}
}
func toCIDR(a route.Addr) (int, bool) {
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
cidr, _ := mask.Size()
return cidr, true
return mask, true
default:
return 0, false
return nil, false
}
}

View File

@@ -1,89 +0,0 @@
//go:build darwin && !ios
package routemanager
import (
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("add", prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("delete", prefix, nexthop, intf)
}
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() {
inet = "-inet6"
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = &net.Interface{Name: "lo0"}
}
}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != nil {
args = append(args, "-interface", intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
}
return nil
}

View File

@@ -1,138 +0,0 @@
//go:build !ios
package routemanager
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var expectedVPNint = "utun100"
var expectedExternalInt = "lo0"
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
destination: "10.10.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return "lo0"
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -1,33 +1,15 @@
//go:build ios
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@@ -3,588 +3,149 @@
package routemanager
import (
"bufio"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
// NetbirdVPNTableID is the ID of the custom routing table used by Netbird.
NetbirdVPNTableID = 0x1BD0
// NetbirdVPNTableName is the name of the custom routing table used by Netbird.
NetbirdVPNTableName = "netbird"
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
type routeInfoInMemory struct {
Family byte
DstLen byte
SrcLen byte
TOS byte
// rtTablesPath is the path to the file containing the routing table names.
rtTablesPath = "/etc/iproute2/rt_tables"
Table byte
Protocol byte
Scope byte
Type byte
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
tableID int
family int
invert bool
suppressPrefix int
description string
Flags uint32
}
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams {
return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}
}
// setupRouting establishes the routing configuration for the VPN, including essential rules
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
//
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
// potential routes received and configured for the VPN. This rule is skipped for the default route and routes
// that are not in the main table.
//
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy() {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
func addToRouteTable(prefix netip.Prefix, addr string) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
}()
rules := getSetupRules()
for _, rule := range rules {
if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
}
return err
}
return nil, nil, nil
}
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
if isLegacy() {
return cleanupRoutingWithRouteManager(routeManager)
addrMask := "/32"
if prefix.Addr().Unmap().Is6() {
addrMask = "/128"
}
var result *multierror.Error
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err))
}
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err))
ip, _, err := net.ParseCIDR(addr + addrMask)
if err != nil {
return err
}
rules := getSetupRules()
for _, rule := range rules {
if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Dst: ipNet,
Gw: ip,
}
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericAddVPNRoute(prefix, intf)
err = netlink.RouteAdd(route)
if err != nil {
return err
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
}
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericRemoveVPNRoute(prefix, intf)
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return err
}
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
addrMask := "/32"
if prefix.Addr().Unmap().Is6() {
addrMask = "/128"
}
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err)
ip, _, err := net.ParseCIDR(addr + addrMask)
if err != nil {
return err
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Dst: ipNet,
Gw: ip,
}
err = netlink.RouteDel(route)
if err != nil {
return err
}
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
return nil, err
}
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, fmt.Errorf("get v6 routes: %w", err)
return nil, err
}
return append(v4Routes, v6Routes...), nil
}
// getRoutes fetches routes from a specific routing table identified by tableID.
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
var prefixList []netip.Prefix
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
}
for _, route := range routes {
if route.Dst != nil {
addr, ok := netip.AddrFromSlice(route.Dst.IP)
if !ok {
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
break loop
case syscall.RTM_NEWROUTE:
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
msg := m
attrs, err := syscall.ParseNetlinkRouteAttr(&msg)
if err != nil {
return nil, err
}
if rt.Family != syscall.AF_INET {
continue loop
}
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
for _, attr := range attrs {
if attr.Attr.Type == syscall.RTA_DST {
addr, ok := netip.AddrFromSlice(attr.Value)
if !ok {
continue
}
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
cidr, _ := mask.Size()
routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
prefixList = append(prefixList, routePrefix)
}
}
}
}
}
return prefixList, nil
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: getAddressFamily(prefix),
}
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink add route: %w", err)
}
return nil
}
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
// tableID specifies the routing table to which the unreachable route will be added.
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}
return nil
}
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}
return nil
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove route: %w", err)
}
return nil
}
func flushRoutes(tableID, family int) error {
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return fmt.Errorf("list routes from table %d: %w", tableID, err)
}
var result *multierror.Error
for i := range routes {
route := routes[i]
// unreachable default routes don't come back with Dst set
if route.Gw == nil && route.Src == nil && route.Dst == nil {
if family == netlink.FAMILY_V4 {
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
} else {
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
}
}
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
}
}
return result.ErrorOrNil()
}
func enableIPForwarding() error {
_, err := setSysctl(ipv4ForwardingPath, 1, false)
return err
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
// and verifies if existing names start with "netbird_".
func entryExists(file *os.File, id int) (bool, error) {
if _, err := file.Seek(0, 0); err != nil {
return false, fmt.Errorf("seek rt_tables: %w", err)
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
var existingID int
var existingName string
if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil {
if existingID == id {
if existingName != NetbirdVPNTableName {
return true, ErrTableIDExists
}
return true, nil
}
}
}
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("scan rt_tables: %w", err)
}
return false, nil
}
// addRoutingTableName adds human-readable names for custom routing tables.
func addRoutingTableName() error {
file, err := os.Open(rtTablesPath)
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return fmt.Errorf("open rt_tables: %w", err)
return err
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Error closing rt_tables: %v", err)
}
}()
exists, err := entryExists(file, NetbirdVPNTableID)
if err != nil {
return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err)
}
if exists {
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
// Reopen the file in append mode to add new entries
if err := file.Close(); err != nil {
log.Errorf("Error closing rt_tables before appending: %v", err)
}
file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("open rt_tables for appending: %w", err)
}
if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil {
return fmt.Errorf("append entry to rt_tables: %w", err)
}
return nil
}
// addRule adds a routing rule to a specific routing table identified by tableID.
func addRule(params ruleParams) error {
rule := netlink.NewRule()
rule.Table = params.tableID
rule.Mark = params.fwmark
rule.Family = params.family
rule.Priority = params.priority
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
return nil
}
// removeRule removes a routing rule from a specific routing table identified by tableID.
func removeRule(params ruleParams) error {
rule := netlink.NewRule()
rule.Table = params.tableID
rule.Mark = params.fwmark
rule.Family = params.family
rule.Invert = params.invert
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if intf != nil {
route.LinkIndex = intf.Index
}
if addr.IsValid() {
route.Gw = addr.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
}
}
return nil
}
func getAddressFamily(prefix netip.Prefix) int {
if prefix.Addr().Is4() {
return netlink.FAMILY_V4
}
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec
}

View File

@@ -1,207 +0,0 @@
//go:build !android
package routemanager
import (
"errors"
"fmt"
"net"
"os"
"strings"
"syscall"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via physical interface",
destination: "10.10.0.2:53",
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.1:53",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...)
}
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
err := netlink.LinkDel(dummy)
if err != nil && !errors.Is(err, syscall.EINVAL) {
t.Logf("Failed to delete dummy interface: %v", err)
}
err = netlink.LinkAdd(dummy)
require.NoError(t, err)
err = netlink.LinkSetUp(dummy)
require.NoError(t, err)
if ipAddressCIDR != "" {
addr, err := netlink.ParseAddr(ipAddressCIDR)
require.NoError(t, err)
err = netlink.AddrAdd(dummy, addr)
require.NoError(t, err)
}
t.Cleanup(func() {
err := netlink.LinkDel(dummy)
assert.NoError(t, err)
})
return dummy.Name
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
t.Helper()
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
require.NoError(t, err)
// Handle existing routes with metric 0
var originalNexthop net.IP
var originalLinkIndex int
if dstIPNet.String() == "0.0.0.0/0" {
var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err)
}
if originalNexthop != nil {
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
switch {
case err != nil && !errors.Is(err, syscall.ESRCH):
t.Logf("Failed to delete route: %v", err)
case err == nil:
t.Cleanup(func() {
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
})
default:
t.Logf("Failed to delete route: %v", err)
}
}
}
link, err := netlink.LinkByName(intf)
require.NoError(t, err)
linkIndex := link.Attrs().Index
route := &netlink.Route{
Dst: dstIPNet,
Gw: gw,
LinkIndex: linkIndex,
}
err = netlink.RouteDel(route)
if err != nil && !errors.Is(err, syscall.ESRCH) {
t.Logf("Failed to delete route: %v", err)
}
err = netlink.RouteAdd(route)
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
require.NoError(t, err)
}
func fetchOriginalGateway(family int) (net.IP, int, error) {
routes, err := netlink.RouteList(nil, family)
if err != nil {
return nil, 0, err
}
for _, route := range routes {
if route.Dst == nil && route.Priority == 0 {
return route.Gw, route.LinkIndex, nil
}
}
return nil, 0, ErrRouteNotFound
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -0,0 +1,120 @@
//go:build !android && !ios
package routemanager
import (
"fmt"
"net"
"net/netip"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
)
var errRouteNotFound = fmt.Errorf("route not found")
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return err
}
if ok {
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return err
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, addr)
}
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
}
addr := netip.MustParseAddr(defaultGateway.String())
if !prefix.Contains(addr) {
log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(addr, 32)
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
if err != nil && err != errRouteNotFound {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop.String())
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, err
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, err
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return removeFromRouteTable(prefix, addr)
}
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
r, err := netroute.New()
if err != nil {
return nil, err
}
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
if gateway == nil {
return preferredSrc, nil
}
return gateway, nil
}

View File

@@ -1,32 +1,24 @@
//go:build !android && !ios
//go:build !android
package routemanager
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strings"
"testing"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface"
)
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddRemoveRoutes(t *testing.T) {
testCases := []struct {
name string
@@ -50,8 +42,6 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@@ -63,34 +53,27 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
} else {
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
}
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
@@ -103,12 +86,12 @@ func TestAddRemoveRoutes(t *testing.T) {
}
}
func TestGetNextHop(t *testing.T) {
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
func TestGetExistingRIBRouteGateway(t *testing.T) {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !gateway.IsValid() {
if gateway == nil {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@@ -130,11 +113,11 @@ func TestGetNextHop(t *testing.T) {
}
}
localIP, _, err := getNextHop(testingPrefix.Addr())
localIP, err := getExistingRIBRouteGateway(testingPrefix)
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if !localIP.IsValid() {
if localIP == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
@@ -145,8 +128,8 @@ func TestGetNextHop(t *testing.T) {
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
t.Log("defaultGateway: ", defaultGateway)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
@@ -188,16 +171,12 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@@ -210,18 +189,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
MockAddr := wgInterface.Address().IP.String()
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addVPNRoute(testCase.preExistingPrefix, intf)
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addVPNRoute(testCase.prefix, intf)
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@@ -231,7 +208,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = removeVPNRoute(testCase.prefix, intf)
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
require.NoError(t, err, "should not return err")
}
@@ -240,7 +217,6 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
@@ -248,6 +224,31 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is4() {
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
@@ -285,136 +286,3 @@ func TestIsSubRange(t *testing.T) {
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is6() {
continue
}
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet()
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
})
_, _, err := setupRouting(nil, wgIface)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixGateway, _, err := getNextHop(prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
}

View File

@@ -1,24 +1,41 @@
//go:build !linux && !ios
//go:build !linux
// +build !linux
package routemanager
import (
"net"
"net/netip"
"os/exec"
"runtime"
log "github.com/sirupsen/logrus"
)
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
func addToRouteTable(prefix netip.Prefix, addr string) error {
cmd := exec.Command("route", "add", prefix.String(), addr)
out, err := cmd.Output()
if err != nil {
return err
}
log.Debugf(string(out))
return nil
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
args := []string{"delete", prefix.String()}
if runtime.GOOS == "darwin" {
args = append(args, addr)
}
cmd := exec.Command("route", args...)
out, err := cmd.Output()
if err != nil {
return err
}
log.Debugf(string(out))
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}

View File

@@ -1,234 +0,0 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package routemanager
import (
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
"github.com/gopacket/gopacket/pcap"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
type testCase struct {
name string
destination string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
destination: "192.0.2.1:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
destination: "172.16.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
func TestRouting(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
filter := createBPFFilter(tc.destination)
handle := startPacketCapture(t, tc.expectedInterface, filter)
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
packet, err := packetSource.NextPacket()
require.NoError(t, err)
verifyPacket(t, packet, tc.expectedPacket)
})
}
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()
inactive, err := pcap.NewInactiveHandle(intf)
require.NoError(t, err, "Failed to create inactive pcap handle")
defer inactive.CleanUp()
err = inactive.SetSnapLen(1600)
require.NoError(t, err, "Failed to set snap length on inactive handle")
err = inactive.SetTimeout(time.Second * 10)
require.NoError(t, err, "Failed to set timeout on inactive handle")
err = inactive.SetImmediateMode(true)
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
handle, err := inactive.Activate()
require.NoError(t, err, "Failed to activate pcap handle")
t.Cleanup(handle.Close)
err = handle.SetBPFFilter(filter)
require.NoError(t, err, "Failed to set BPF filter")
return handle
}
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
t.Helper()
if dialer == nil {
dialer = &net.Dialer{}
}
if sourcePort != 0 {
localUDPAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: sourcePort,
}
switch dialer := dialer.(type) {
case *nbnet.Dialer:
dialer.LocalAddr = localUDPAddr
case *net.Dialer:
dialer.LocalAddr = localUDPAddr
default:
t.Fatal("Unsupported dialer type")
}
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
conn, err := dialer.Dial("udp", destination)
require.NoError(t, err, "Failed to dial UDP")
defer conn.Close()
data, err := msg.Pack()
require.NoError(t, err, "Failed to pack DNS message")
_, err = conn.Write(data)
if err != nil {
if strings.Contains(err.Error(), "required key not available") {
t.Logf("Ignoring WireGuard key error: %v", err)
return
}
t.Fatalf("Failed to send DNS query: %v", err)
}
}
func createBPFFilter(destination string) string {
host, port, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
return "udp"
}
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
t.Helper()
ipLayer := packet.Layer(layers.LayerTypeIPv4)
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
ip, ok := ipLayer.(*layers.IPv4)
require.True(t, ok, "Failed to cast to IPv4 layer")
// Convert both source and destination IP addresses to 16-byte representation
expectedSrcIP := exp.SrcIP.To16()
actualSrcIP := ip.SrcIP.To16()
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
expectedDstIP := exp.DstIP.To16()
actualDstIP := ip.DstIP.To16()
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
if exp.UDP {
udpLayer := packet.Layer(layers.LayerTypeUDP)
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
udp, ok := udpLayer.(*layers.UDP)
require.True(t, ok, "Failed to cast to UDP layer")
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
}
if exp.TCP {
tcpLayer := packet.Layer(layers.LayerTypeTCP)
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
tcp, ok := tcpLayer.(*layers.TCP)
require.True(t, ok, "Failed to cast to TCP layer")
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
}
}

View File

@@ -1,23 +1,13 @@
//go:build windows
// +build windows
package routemanager
import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
type Win32_IP4RouteTable struct {
@@ -25,47 +15,23 @@ type Win32_IP4RouteTable struct {
Mask string
}
var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func getRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
var routes []Win32_IP4RouteTable
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
var routes []Win32_IP4RouteTable
err := wmi.Query(query, &routes)
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
return nil, err
}
prefixList = nil
var prefixList []netip.Prefix
for _, route := range routes {
addr, err := netip.ParseAddr(route.Destination)
if err != nil {
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
continue
}
maskSlice := net.ParseIP(route.Mask).To4()
if maskSlice == nil {
log.Warnf("Unable to parse route mask %s", route.Mask)
continue
}
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
@@ -76,66 +42,5 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
prefixList = append(prefixList, routePrefix)
}
}
lastUpdate = time.Now()
return prefixList, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
args := []string{"add", prefix.String()}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
if intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index))
}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -1,289 +0,0 @@
package routemanager
import (
"context"
"encoding/json"
"fmt"
"net"
"os/exec"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
var expectedExtInt = "Ethernet1"
type RouteInfo struct {
NextHop string `json:"nexthop"`
InterfaceAlias string `json:"interfacealias"`
RouteMetric int `json:"routemetric"`
}
type FindNetRouteOutput struct {
IPAddress string `json:"IPAddress"`
InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"`
NextHop string `json:"NextHop"`
DestinationPrefix string `json:"DestinationPrefix"`
}
type testCase struct {
name string
destination string
expectedSourceIP string
expectedDestPrefix string
expectedNextHop string
expectedInterface string
dialer dialer
}
var expectedVPNint = "wgtest0"
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
destination: "192.0.2.1:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "128.0.0.0/1",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedDestPrefix: "192.0.2.1/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedDestPrefix: "10.0.0.2/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedDestPrefix: "172.16.0.2/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To unique vpn route without custom dialer via vpn",
destination: "172.16.0.2:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "172.16.0.0/12",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route without custom dialer via vpn interface",
destination: "10.10.0.2:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "10.10.0.0/24",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
}
func TestRouting(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
route, err := fetchOriginalGateway()
require.NoError(t, err, "Failed to fetch original gateway")
ip, err := fetchInterfaceIP(route.InterfaceAlias)
require.NoError(t, err, "Failed to fetch interface IP")
output := testRoute(t, tc.destination, tc.dialer)
if tc.expectedInterface == expectedExtInt {
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
} else {
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
}
})
}
}
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
func fetchInterfaceIP(interfaceAlias string) (string, error) {
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
out, err := exec.Command("powershell", "-Command", script).Output()
if err != nil {
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
}
ip := strings.TrimSpace(string(out))
return ip, nil
}
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := dialer.DialContext(ctx, "udp", destination)
require.NoError(t, err, "Failed to dial destination")
defer func() {
err := conn.Close()
assert.NoError(t, err, "Failed to close connection")
}()
host, _, err := net.SplitHostPort(destination)
require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute")
var outputs []FindNetRouteOutput
err = json.Unmarshal(out, &outputs)
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
require.Greater(t, len(outputs), 0, "No route found for destination")
combinedOutput := combineOutputs(outputs)
return combinedOutput
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
}
var routeInfo RouteInfo
err = json.Unmarshal(output, &routeInfo)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
}
return &routeInfo, nil
}
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
t.Helper()
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
}
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput
for _, output := range outputs {
if output.IPAddress != "" {
combined.IPAddress = output.IPAddress
}
if output.InterfaceIndex != 0 {
combined.InterfaceIndex = output.InterfaceIndex
}
if output.InterfaceAlias != "" {
combined.InterfaceAlias = output.InterfaceAlias
}
if output.AddressFamily != 0 {
combined.AddressFamily = output.AddressFamily
}
if output.NextHop != "" {
combined.NextHop = output.NextHop
}
if output.DestinationPrefix != "" {
combined.DestinationPrefix = output.DestinationPrefix
}
}
return &combined
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
}

View File

@@ -1,132 +0,0 @@
package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
route "github.com/netbirdio/netbird/route"
)
type RouteSelector struct {
selectedRoutes map[string]struct{}
selectAll bool
}
func NewRouteSelector() *RouteSelector {
return &RouteSelector{
selectedRoutes: map[string]struct{}{},
// default selects all routes
selectAll: true,
}
}
// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
if !appendRoute {
rs.selectedRoutes = map[string]struct{}{}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
rs.selectedRoutes[route] = struct{}{}
}
rs.selectAll = false
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{}
}
// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool {
if rs.selectAll {
return true
}
_, selected := rs.selectedRoutes[routeID]
return selected
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
if rs.selectAll {
return maps.Clone(routes)
}
filtered := map[string][]*route.Route{}
for id, rt := range routes {
netID := id
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt
}
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -1,275 +0,0 @@
package routeselector_test
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
selectRoutes []string
append bool
wantSelected []string
wantError bool
}{
{
name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes, initial all deselected",
initialSelected: []string{},
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2", "route3"},
wantSelected: []string{"route2", "route3"},
},
{
name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route1"},
wantError: true,
},
{
name: "Append route with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route1", "route2"},
},
{
name: "Append route without initial selection",
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial some selected",
initialSelected: []string{"route1"},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.SelectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
deselectRoutes []string
wantSelected []string
wantError bool
}{
{
name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route3"},
},
{
name: "Deselect specific routes, initial all deselected",
initialSelected: []string{},
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"},
wantSelected: []string{"route2"},
},
{
name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route2"},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{},
},
{
name: "Initial some selected",
initialSelected: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.DeselectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
assert.True(t, rs.IsSelected("route1"))
assert.True(t, rs.IsSelected("route2"))
assert.False(t, rs.IsSelected("route3"))
assert.False(t, rs.IsSelected("route4"))
}
func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
routes := map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
}, filtered)
}

View File

@@ -1,82 +0,0 @@
package internal
import (
"context"
"os/exec"
"strings"
"sync"
"time"
"github.com/netbirdio/netbird/client/internal/peer"
)
type SessionWatcher struct {
ctx context.Context
mutex sync.Mutex
peerStatusRecorder *peer.Status
watchTicker *time.Ticker
sendNotification bool
onExpireListener func()
}
// NewSessionWatcher creates a new instance of SessionWatcher.
func NewSessionWatcher(ctx context.Context, peerStatusRecorder *peer.Status) *SessionWatcher {
s := &SessionWatcher{
ctx: ctx,
peerStatusRecorder: peerStatusRecorder,
watchTicker: time.NewTicker(2 * time.Second),
}
go s.startWatcher()
return s
}
// SetOnExpireListener sets the callback func to be called when the session expires.
func (s *SessionWatcher) SetOnExpireListener(onExpire func()) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.onExpireListener = onExpire
}
// startWatcher continuously checks if the session requires login and
// calls the onExpireListener if login is required.
func (s *SessionWatcher) startWatcher() {
for {
select {
case <-s.ctx.Done():
s.watchTicker.Stop()
return
case <-s.watchTicker.C:
managementState := s.peerStatusRecorder.GetManagementState()
if managementState.Connected {
s.sendNotification = true
}
isLoginRequired := s.peerStatusRecorder.IsLoginRequired()
if isLoginRequired && s.sendNotification && s.onExpireListener != nil {
s.mutex.Lock()
s.onExpireListener()
s.sendNotification = false
s.mutex.Unlock()
}
}
}
}
// CheckUIApp checks whether UI application is running.
func CheckUIApp() bool {
cmd := exec.Command("ps", "-ef")
output, err := cmd.Output()
if err != nil {
return false
}
lines := strings.Split(string(output), "\n")
for _, line := range lines {
if strings.Contains(line, "netbird-ui") && !strings.Contains(line, "grep") {
return true
}
}
return false
}

View File

@@ -1,24 +0,0 @@
package stdnet
import (
"net"
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
)
// Dial connects to the address on the named network.
func (n *Net) Dial(network, address string) (net.Conn, error) {
return nbnet.NewDialer().Dial(network, address)
}
// DialUDP connects to the address on the named UDP network.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
return nbnet.DialUDP(network, laddr, raddr)
}
// DialTCP connects to the address on the named TCP network.
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
return nbnet.DialTCP(network, laddr, raddr)
}

View File

@@ -1,20 +0,0 @@
package stdnet
import (
"context"
"net"
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
)
// ListenPacket listens for incoming packets on the given network and address.
func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) {
return nbnet.NewListener().ListenPacket(context.Background(), network, address)
}
// ListenUDP acts like ListenPacket for UDP networks.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
return nbnet.ListenUDP(network, locAddr)
}

View File

@@ -12,12 +12,10 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
// WGEBPFProxy definition for proxy with EBPF support
@@ -30,7 +28,7 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
rawConn net.PacketConn
conn transport.UDPConn
conn *net.UDPConn
}
// NewWGEBPFProxy create new WGEBPFProxy instance
@@ -68,7 +66,7 @@ func (p *WGEBPFProxy) Listen() error {
IP: net.ParseIP("127.0.0.1"),
}
conn, err := nbnet.ListenUDP("udp", &addr)
p.conn, err = net.ListenUDP("udp", &addr)
if err != nil {
cErr := p.Free()
if cErr != nil {
@@ -76,7 +74,6 @@ func (p *WGEBPFProxy) Listen() error {
}
return err
}
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
@@ -211,41 +208,20 @@ generatePort:
}
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
return nil, err
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
return nil, err
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
return nil, err
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)))
}
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {

View File

@@ -6,8 +6,6 @@ import (
"net"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
// WGUserSpaceProxy proxies
@@ -35,7 +33,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn
var err error
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err

View File

@@ -82,7 +82,6 @@ func (c *Client) Run(fd int32, interfaceName string) error {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
var ctx context.Context
//nolint

View File

@@ -71,42 +71,6 @@ func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled store if rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,6 @@ syntax = "proto3";
import "google/protobuf/descriptor.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/duration.proto";
option go_package = "/proto";
@@ -27,21 +26,6 @@ service DaemonService {
// GetConfig of the daemon.
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {}
// List available network routes
rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {}
// Select specific routes
rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// Deselect specific routes
rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// DebugBundle creates a debug bundle
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
// SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
};
message LoginRequest {
@@ -50,7 +34,7 @@ message LoginRequest {
// This is the old PreSharedKey field which will be deprecated in favor of optionalPreSharedKey field that is defined as optional
// to allow clearing of preshared key while being able to persist in the config file.
string preSharedKey = 2 [deprecated = true];
string preSharedKey = 2 [deprecated=true];
// managementUrl to authenticate.
string managementUrl = 3;
@@ -79,14 +63,6 @@ message LoginRequest {
optional int64 wireguardPort = 12;
optional string optionalPreSharedKey = 13;
optional bool disableAutoConnect = 14;
optional bool serverSSHAllowed = 15;
optional bool rosenpassPermissive = 16;
repeated string extraIFaceBlacklist = 17;
}
message LoginResponse {
@@ -158,9 +134,6 @@ message PeerState {
google.protobuf.Timestamp lastWireguardHandshake = 12;
int64 bytesRx = 13;
int64 bytesTx = 14;
bool rosenpassEnabled = 15;
repeated string routes = 16;
google.protobuf.Duration latency = 17;
}
// LocalPeerState contains the latest state of the local peer
@@ -169,9 +142,6 @@ message LocalPeerState {
string pubKey = 2;
bool kernelInterface = 3;
string fqdn = 4;
bool rosenpassEnabled = 5;
bool rosenpassPermissive = 6;
repeated string routes = 7;
}
// SignalState contains the latest state of a signal connection
@@ -195,13 +165,6 @@ message RelayState {
string error = 3;
}
message NSGroupState {
repeated string servers = 1;
repeated string domains = 2;
bool enabled = 3;
string error = 4;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -209,54 +172,4 @@ message FullStatus {
LocalPeerState localPeerState = 3;
repeated PeerState peers = 4;
repeated RelayState relays = 5;
repeated NSGroupState dns_servers = 6;
}
message ListRoutesRequest {
}
message ListRoutesResponse {
repeated Route routes = 1;
}
message SelectRoutesRequest {
repeated string routeIDs = 1;
bool append = 2;
bool all = 3;
}
message SelectRoutesResponse {
}
message Route {
string ID = 1;
string network = 2;
bool selected = 3;
}
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
}
message DebugBundleResponse {
string path = 1;
}
enum LogLevel {
UNKNOWN = 0;
PANIC = 1;
FATAL = 2;
ERROR = 3;
WARN = 4;
INFO = 5;
DEBUG = 6;
TRACE = 7;
}
message SetLogLevelRequest {
LogLevel level = 1;
}
message SetLogLevelResponse {
}

View File

@@ -31,16 +31,6 @@ type DaemonServiceClient interface {
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
// List available network routes
ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
}
type daemonServiceClient struct {
@@ -105,51 +95,6 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
return out, nil
}
func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) {
out := new(ListRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -167,16 +112,6 @@ type DaemonServiceServer interface {
Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
// List available network routes
ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -202,21 +137,6 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do
func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented")
}
func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
}
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -338,96 +258,6 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SelectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SelectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeselectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DebugBundleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DebugBundle(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DebugBundle",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetLogLevelRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetLogLevel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetLogLevel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -459,26 +289,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetConfig",
Handler: _DaemonService_GetConfig_Handler,
},
{
MethodName: "ListRoutes",
Handler: _DaemonService_ListRoutes_Handler,
},
{
MethodName: "SelectRoutes",
Handler: _DaemonService_SelectRoutes_Handler,
},
{
MethodName: "DeselectRoutes",
Handler: _DaemonService_DeselectRoutes_Handler,
},
{
MethodName: "DebugBundle",
Handler: _DaemonService_DebugBundle_Handler,
},
{
MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto",

View File

@@ -1,175 +0,0 @@
package server
import (
"archive/zip"
"bufio"
"context"
"fmt"
"io"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
if err := bundlePath.Close(); err != nil {
log.Errorf("failed to close zip file: %v", err)
}
if err != nil {
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
log.Errorf("Failed to remove zip file: %v", err2)
}
}
}()
archive := zip.NewWriter(bundlePath)
defer func() {
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
scanner := bufio.NewScanner(reader)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
defer func() {
if err := writer.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err)
}
}()
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
errChan <- fmt.Errorf("write line to writer: %w", err)
return
}
}
if err := scanner.Err(); err != nil {
errChan <- fmt.Errorf("read line from scanner: %w", err)
return
}
}
// SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
level, err := log.ParseLevel(req.Level.String())
if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err)
}
log.SetLevel(level)
log.Infof("Log level set to %s", level.String())
return &proto.SetLogLevelResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
Method: zip.Deflate,
}
writer, err := archive.CreateHeader(header)
if err != nil {
return fmt.Errorf("create zip file header: %w", err)
}
if _, err := io.Copy(writer, reader); err != nil {
return fmt.Errorf("write file to zip: %w", err)
}
return nil
}
func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
}
for _, nsGroup := range status.NSGroupStates {
for _, domain := range nsGroup.Domains {
a.AnonymizeDomain(domain)
}
}
for _, relay := range status.Relays {
if relay.URI != nil {
a.AnonymizeURI(relay.URI.String())
}
}
}

View File

@@ -1,110 +0,0 @@
package server
import (
"context"
"fmt"
"net/netip"
"sort"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
)
type selectRoute struct {
NetID string
Network netip.Prefix
Selected bool
}
// ListRoutes returns a list of all available routes.
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.engine == nil {
return nil, fmt.Errorf("not connected")
}
routesMap := s.engine.GetClientRoutesWithNetID()
routeSelector := s.engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
route := &selectRoute{
NetID: id,
Network: rt[0].Network,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
}
sort.Slice(routes, func(i, j int) bool {
iPrefix := routes[i].Network.Bits()
jPrefix := routes[j].Network.Bits()
if iPrefix == jPrefix {
iAddr := routes[i].Network.Addr()
jAddr := routes[j].Network.Addr()
if iAddr == jAddr {
return routes[i].NetID < routes[j].NetID
}
return iAddr.String() < jAddr.String()
}
return iPrefix < jPrefix
})
var pbRoutes []*proto.Route
for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{
ID: route.NetID,
Network: route.Network.String(),
Selected: route.Selected,
})
}
return &proto.ListRoutesResponse{
Routes: pbRoutes,
}, nil
}
// SelectRoutes selects specific routes based on the client request.
func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.SelectAllRoutes()
} else {
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}
// DeselectRoutes deselects specific routes based on the client request.
func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.DeselectAllRoutes()
} else {
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}

View File

@@ -3,17 +3,11 @@ package server
import (
"context"
"fmt"
"os"
"os/exec"
"runtime"
"strconv"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
@@ -21,26 +15,13 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
probeThreshold = time.Second * 5
retryInitialIntervalVar = "NB_CONN_RETRY_INTERVAL_TIME"
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
defaultInitialRetryTime = 14 * 24 * time.Hour
defaultMaxRetryInterval = 60 * time.Minute
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
)
const probeThreshold = time.Second * 5
// Server for service control.
type Server struct {
@@ -57,10 +38,7 @@ type Server struct {
config *internal.Config
proto.UnimplementedDaemonServiceServer
engine *internal.Engine
statusRecorder *peer.Status
sessionWatcher *internal.SessionWatcher
mgmProbe *internal.Probe
signalProbe *internal.Probe
@@ -134,123 +112,17 @@ func (s *Server) Start() error {
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
}
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
if s.sessionWatcher == nil {
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
}
return nil
}
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
engineChan chan<- *internal.Engine,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
go func() {
t := time.NewTicker(24 * time.Hour)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
if retryStarted {
mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
retryStarted = false
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
if err := internal.RunClientWithProbes(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
log.Errorf("init connections: %v", err)
}
}()
runOperation := func() error {
log.Tracef("running client connection")
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
if config.DisableAutoConnect {
return backoff.Permanent(err)
}
if !retryStarted {
retryStarted = true
backOff.Reset()
}
log.Tracef("client connection exited")
return fmt.Errorf("client connection exited")
}
err := backoff.Retry(runOperation, backOff)
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
log.Errorf("received an error when trying to connect: %v", err)
} else {
log.Tracef("retry canceled")
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
return nil
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -332,21 +204,6 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.latestConfigInput.RosenpassEnabled = msg.RosenpassEnabled
}
if msg.RosenpassPermissive != nil {
inputConfig.RosenpassPermissive = msg.RosenpassPermissive
s.latestConfigInput.RosenpassPermissive = msg.RosenpassPermissive
}
if msg.ServerSSHAllowed != nil {
inputConfig.ServerSSHAllowed = msg.ServerSSHAllowed
s.latestConfigInput.ServerSSHAllowed = msg.ServerSSHAllowed
}
if msg.DisableAutoConnect != nil {
inputConfig.DisableAutoConnect = msg.DisableAutoConnect
s.latestConfigInput.DisableAutoConnect = msg.DisableAutoConnect
}
if msg.InterfaceName != nil {
inputConfig.InterfaceName = msg.InterfaceName
s.latestConfigInput.InterfaceName = msg.InterfaceName
@@ -358,11 +215,6 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.latestConfigInput.WireguardPort = &port
}
if len(msg.ExtraIFaceBlacklist) > 0 {
inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
}
s.mutex.Unlock()
if msg.OptionalPreSharedKey != nil {
@@ -564,14 +416,16 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
go func() {
if err := internal.RunClientWithProbes(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe); err != nil {
log.Errorf("run client connection: %v", err)
return
}
}()
return &proto.UpResponse{}, nil
}
@@ -588,8 +442,6 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.engine = nil
return &proto.DownResponse{}, nil
}
@@ -610,9 +462,9 @@ func (s *Server) Status(
if s.statusRecorder == nil {
s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String())
} else {
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus {
s.runProbes()
@@ -672,32 +524,6 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
PreSharedKey: preSharedKey,
}, nil
}
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
if !isUIActive {
if err := sendTerminalNotification(); err != nil {
log.Errorf("send session expire terminal notification: %v", err)
}
}
}
}
// watchEngine watches the engine channel and updates the engine state
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
log.Tracef("Started watching engine")
for {
select {
case <-ctx.Done():
s.engine = nil
log.Tracef("Stopped watching engine")
return
case engine := <-engineChan:
log.Tracef("Received engine from watcher")
s.engine = engine
}
}
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
@@ -705,6 +531,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
Relays: []*proto.RelayState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
@@ -723,9 +550,6 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Routes = maps.Keys(fullStatus.LocalPeerState.Routes)
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
@@ -743,9 +567,6 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
@@ -761,47 +582,5 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
pbDnsState := &proto.NSGroupState{
Servers: dnsState.Servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {
message := "NetBird connection session expired\n\nPlease re-authenticate to connect to the network."
echoCmd := exec.Command("echo", message)
wallCmd := exec.Command("sudo", "wall")
echoCmdStdout, err := echoCmd.StdoutPipe()
if err != nil {
return err
}
wallCmd.Stdin = echoCmdStdout
if err := echoCmd.Start(); err != nil {
return err
}
if err := wallCmd.Start(); err != nil {
return err
}
if err := echoCmd.Wait(); err != nil {
return err
}
return wallCmd.Wait()
}

View File

@@ -1,160 +0,0 @@
package server
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/management-integrations/integrations"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal()
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
s := New(ctx, t.TempDir()+"/config.json", "debug")
s.latestConfigInput.ManagementURL = "http://" + mgmtAddr
config, err := internal.UpdateOrCreateConfig(s.latestConfigInput)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Signal: &server.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStoreFromJson(config.Datadir, nil)
if err != nil {
return nil, "", err
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal() (*grpc.Server, string, error) {
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -1,24 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectAlibabaCloud(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://100.100.100.200/latest/", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return "Alibaba Cloud"
}
return ""
}

View File

@@ -1,57 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectAWS(ctx context.Context) string {
v1ResultChan := make(chan bool, 1)
v2ResultChan := make(chan bool, 1)
go func() {
v1ResultChan <- detectAWSIDMSv1(ctx)
}()
go func() {
v2ResultChan <- detectAWSIDMSv2(ctx)
}()
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
if v1Result || v2Result {
return "Amazon Web Services"
}
return ""
}
func detectAWSIDMSv1(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/latest/", nil)
if err != nil {
return false
}
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
func detectAWSIDMSv2(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "PUT", "http://169.254.169.254/latest/api/token", nil)
if err != nil {
return false
}
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600")
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}

View File

@@ -1,25 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectAzure(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/metadata/instance?api-version=2021-02-01", nil)
if err != nil {
return ""
}
req.Header.Set("Metadata", "true")
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return "Microsoft Azure"
}
return ""
}

View File

@@ -1,63 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
"sync"
"time"
)
/*
This packages is inspired by the work of the original author (https://github.com/perlogix), but it has been modified to fit the needs of the project.
Original project: https://github.com/perlogix/libdetectcloud
*/
var hc = &http.Client{Timeout: 300 * time.Millisecond}
func Detect(ctx context.Context) string {
subCtx, cancel := context.WithCancel(context.Background())
defer cancel()
funcs := []func(context.Context) string{
detectAlibabaCloud,
detectAWS,
detectAzure,
detectDigitalOcean,
detectGCP,
detectOracle,
detectVultr,
}
results := make(chan string, len(funcs))
var wg sync.WaitGroup
for _, fn := range funcs {
wg.Add(1)
go func(f func(context.Context) string) {
defer wg.Done()
select {
case <-subCtx.Done():
return
default:
if result := f(ctx); result != "" {
results <- result
cancel()
}
}
}(fn)
}
go func() {
wg.Wait()
close(results)
}()
for result := range results {
if result != "" {
return result
}
}
return ""
}

View File

@@ -1,24 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectDigitalOcean(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/metadata/v1/", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return "Digital Ocean"
}
return ""
}

View File

@@ -1,25 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectGCP(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil)
if err != nil {
return ""
}
req.Header.Add("Metadata-Flavor", "Google")
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return "Google Cloud Platform"
}
return ""
}

View File

@@ -1,56 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectOracle(ctx context.Context) string {
v1ResultChan := make(chan bool, 1)
v2ResultChan := make(chan bool, 1)
go func() {
v1ResultChan <- detectOracleIDMSv1(ctx)
}()
go func() {
v2ResultChan <- detectOracleIDMSv2(ctx)
}()
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
if v1Result || v2Result {
return "Oracle"
}
return ""
}
func detectOracleIDMSv1(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/opc/v1/instance/", nil)
if err != nil {
return false
}
req.Header.Add("Authorization", "Bearer Oracle")
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
func detectOracleIDMSv2(ctx context.Context) bool {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/opc/v2/instance/", nil)
if err != nil {
return false
}
req.Header.Add("Authorization", "Bearer Oracle")
resp, err := hc.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}

View File

@@ -1,24 +0,0 @@
package detect_cloud
import (
"context"
"net/http"
)
func detectVultr(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/v1.json", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return "Vultr"
}
return ""
}

View File

@@ -1,53 +0,0 @@
package detect_platform
import (
"context"
"net/http"
"sync"
"time"
)
var hc = &http.Client{Timeout: 300 * time.Millisecond}
func Detect(ctx context.Context) string {
subCtx, cancel := context.WithCancel(context.Background())
defer cancel()
funcs := []func(context.Context) string{
detectOpenStack,
detectContainer,
}
results := make(chan string, len(funcs))
var wg sync.WaitGroup
for _, fn := range funcs {
wg.Add(1)
go func(f func(context.Context) string) {
defer wg.Done()
select {
case <-subCtx.Done():
return
default:
if result := f(ctx); result != "" {
results <- result
cancel()
}
}
}(fn)
}
go func() {
wg.Wait()
close(results)
}()
for result := range results {
if result != "" {
return result
}
}
return ""
}

View File

@@ -1,17 +0,0 @@
package detect_platform
import (
"context"
"os"
)
func detectContainer(ctx context.Context) string {
if _, exists := os.LookupEnv("KUBERNETES_SERVICE_HOST"); exists {
return "Kubernetes"
}
if _, err := os.Stat("/.dockerenv"); err == nil {
return "Docker"
}
return ""
}

View File

@@ -1,24 +0,0 @@
package detect_platform
import (
"context"
"net/http"
)
func detectOpenStack(ctx context.Context) string {
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254/openstack", nil)
if err != nil {
return ""
}
resp, err := hc.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return "OpenStack"
}
return ""
}

View File

@@ -2,8 +2,6 @@ package system
import (
"context"
"net"
"net/netip"
"strings"
"google.golang.org/grpc/metadata"
@@ -20,21 +18,12 @@ const OsVersionCtxKey = "OsVersion"
// OsNameCtxKey context key for operating system name
const OsNameCtxKey = "OsName"
type NetworkAddress struct {
NetIP netip.Prefix
Mac string
}
type Environment struct {
Cloud string
Platform string
}
// Info is an object that contains machine information
// Most of the code is taken from https://github.com/matishsiao/goInfo
type Info struct {
GoOS string
Kernel string
Core string
Platform string
OS string
OSVersion string
@@ -42,12 +31,6 @@ type Info struct {
CPUs int
WiretrusteeVersion string
UIVersion string
KernelVersion string
NetworkAddresses []NetworkAddress
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@@ -79,53 +62,3 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
func GetDesktopUIUserAgent() string {
return "netbird-desktop-ui/" + version.NetbirdVersion()
}
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

Some files were not shown because too many files have changed in this diff Show More