Compare commits

..

2 Commits

Author SHA1 Message Date
Pascal Fischer
27c3a4c5d6 simplify storage inheritance 2024-03-14 11:42:25 +01:00
Pascal Fischer
f31b06fc92 add example setup for management refactor 2024-03-13 23:07:00 +01:00
405 changed files with 10997 additions and 25288 deletions

View File

@@ -14,7 +14,7 @@ jobs:
test:
strategy:
matrix:
store: ['sqlite']
store: ['jsonfile', 'sqlite']
runs-on: macos-latest
steps:
- name: Install Go
@@ -32,9 +32,6 @@ jobs:
restore-keys: |
macos-go-
- name: Install libpcap
run: brew install libpcap
- name: Install modules
run: go mod tidy

View File

@@ -1,39 +0,0 @@
name: Test Code FreeBSD
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
prepare: |
pkg install -y curl
pkg install -y git
run: |
set -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
tar zxf go.tar.gz
mv go /usr/local/go
ln -s /usr/local/go/bin/go /usr/local/bin/go
go mod tidy
go test -timeout 5m -p 1 ./iface/...
go test -timeout 5m -p 1 ./client/...
cd client
go build .
cd ..

View File

@@ -14,8 +14,8 @@ jobs:
test:
strategy:
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
arch: ['386','amd64']
store: ['jsonfile', 'sqlite']
runs-on: ubuntu-latest
steps:
- name: Install Go
@@ -36,11 +36,7 @@ jobs:
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
- name: Install modules
run: go mod tidy
@@ -71,7 +67,7 @@ jobs:
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
- name: Install modules
run: go mod tidy
@@ -86,10 +82,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@@ -111,15 +104,12 @@ jobs:
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -44,9 +44,10 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- run: "[Environment]::SetEnvironmentVariable('NETBIRD_STORE_ENGINE', 'jsonfile', 'Machine')"
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,
ignore_words_list: erro,clienta
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -33,10 +33,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:
@@ -44,7 +40,7 @@ jobs:
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:

View File

@@ -11,7 +11,7 @@ concurrency:
cancel-in-progress: true
jobs:
android_build:
andrloid_build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
@@ -38,10 +38,10 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android netbird lib
- name: build android nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
@@ -56,10 +56,10 @@ jobs:
with:
go-version: "1.21.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
- name: build iOS nebtird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -7,7 +7,17 @@ on:
branches:
- main
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.goreleaser.yml'
- '.goreleaser_ui.yaml'
- '.goreleaser_ui_darwin.yaml'
- '.github/workflows/release.yml'
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
- 'client/ui/**'
env:
SIGN_PIPE_VER: "v0.0.11"
@@ -96,27 +106,6 @@ jobs:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v3
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v3
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v3
with:
name: macos-packages
path: dist/netbird_darwin**
retention-days: 3
release_ui:
runs-on: ubuntu-latest
@@ -173,7 +162,7 @@ jobs:
retention-days: 3
release_ui_darwin:
runs-on: macos-latest
runs-on: macos-11
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@@ -178,79 +178,34 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL
- name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres
- name: test Caddy file gen
run: test -f Caddyfile
- name: test docker-compose file gen postgres
- name: test docker-compose file gen
run: test -f docker-compose.yml
- name: test management.json file gen postgres
- name: test management.json file gen
run: test -f management.json
- name: test turnserver.conf file gen postgres
- name: test turnserver.conf file gen
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen postgres
- name: test zitadel.env file gen
run: test -f zitadel.env
- name: test dashboard.env file gen postgres
- name: test dashboard.env file gen
run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

View File

@@ -130,10 +130,3 @@ issues:
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -3,10 +3,8 @@ builds:
- id: netbird-ui-darwin
dir: client/ui
binary: netbird-ui
env:
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
env: [CGO_ENABLED=1]
goos:
- darwin
goarch:

View File

@@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status,
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

View File

@@ -40,13 +40,11 @@
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open-Source Network Security in a Single Platform
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
![download (2)](https://github.com/netbirdio/netbird/assets/700848/16210ac2-7265-44c1-8d4e-8fae85534dac)
### Key features
@@ -78,7 +76,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
- **Public domain** name pointing to the VM.
**Software requirements:**
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
- [curl](https://curl.se/) installed.
@@ -95,9 +93,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
@@ -121,7 +119,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png)
### Testimonials
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
### Legal
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.

View File

@@ -1,5 +1,5 @@
FROM alpine:3.19
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird

View File

@@ -1,5 +1,3 @@
//go:build android
package android
import (
@@ -16,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
)
// ConnectionListener export internal Listener for mobile
@@ -57,17 +54,14 @@ type Client struct {
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
return &Client{
cfgFile: cfgFile,
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
@@ -90,9 +84,6 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
@@ -106,8 +97,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -132,8 +122,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
// Stop the internal client and free the resources

View File

@@ -1,212 +0,0 @@
package anonymize
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"regexp"
"slices"
"strings"
)
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
currentAnonIPv4 netip.Addr
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
return &Anonymizer{
ipAnonymizer: map[netip.Addr]netip.Addr{},
domainAnonymizer: map[string]string{},
currentAnonIPv4: startIPv4,
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
}
}
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
if ip.IsLoopback() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
a.isInAnonymizedRange(ip) {
return ip
}
if _, ok := a.ipAnonymizer[ip]; !ok {
if ip.Is4() {
a.ipAnonymizer[ip] = a.currentAnonIPv4
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
} else {
a.ipAnonymizer[ip] = a.currentAnonIPv6
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
}
}
return a.ipAnonymizer[ip]
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
return true
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
return true
}
return false
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
}
return a.AnonymizeIP(addr).String()
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return uri
}
var anonymizedHost string
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
u.Host = anonymizedHost
}
return u.String()
}
func (a *Anonymizer) AnonymizeString(str string) string {
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
for domain, anonDomain := range a.domainAnonymizer {
str = strings.ReplaceAll(str, domain, anonDomain)
}
str = a.AnonymizeSchemeURI(str)
str = a.AnonymizeDNSLogLine(str)
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
}
if slices.Contains(wellKnown, addr.String()) {
return true
}
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
return cgnatRange.Contains(addr)
}
func generateRandomString(length int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
continue
}
result[i] = letters[num.Int64()]
}
return string(result)
}

View File

@@ -1,223 +0,0 @@
package anonymize_test
import (
"net/netip"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
)
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
name string
ip string
expect string
}{
{"Well known", "8.8.8.8", "8.8.8.8"},
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ip := netip.MustParseAddr(tc.ip)
anonymizedIP := anonymizer.AnonymizeIP(ip)
if anonymizedIP.String() != tc.expect {
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
}
})
}
}
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
domain string
expectPattern string
shouldAnonymize bool
}{
{
"General Domain",
"example.com",
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDomain(tc.domain)
if tc.shouldAnonymize {
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
} else {
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
}
})
}
}
func TestAnonymizeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
uri string
regex string
}{
{
"HTTP URI with Port",
"http://example.com:80/path",
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
},
{
"HTTP URI without Port",
"http://example.com/path",
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
},
{
"Opaque URI with Port",
"stun:example.com:80?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
},
{
"Opaque URI without Port",
"stun:example.com?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeURI(tc.uri)
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizeSchemeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeSchemeURI(tc.input)
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizString_MemorizedDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "This is a test string including the domain example.com which should be anonymized."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
}
func TestAnonymizeString_DoubleURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "Check out our site at https://example.com for more info."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
}
func TestAnonymizeString_IPAddresses(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
tests := []struct {
name string
input string
expect string
}{
{
name: "IPv4 Address",
input: "Error occurred at IP 122.138.1.1",
expect: "Error occurred at IP 198.51.100.0",
},
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeString(tc.input)
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
})
}
}

View File

@@ -1,255 +0,0 @@
package cmd
import (
"context"
"fmt"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
}
var debugBundleCmd = &cobra.Command{
Use: "bundle",
Example: " netbird debug bundle",
Short: "Create a debug bundle",
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
RunE: debugBundle,
}
var logCmd = &cobra.Command{
Use: "log",
Short: "Manage logging for the Netbird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
}
var logLevelCmd = &cobra.Command{
Use: "level <level>",
Short: "Set the logging level for this session",
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
Available log levels are:
panic: for panic level, highest level of severity
fatal: for fatal level errors that cause the program to exit
error: for error conditions
warn: for warning conditions
info: for informational messages
debug: for debug-level messages
trace: for trace-level messages, which include more fine-grained information than debug`,
Args: cobra.ExactArgs(1),
RunE: setLogLevel,
}
var forCmd = &cobra.Command{
Use: "for <time>",
Short: "Run debug logs for a specified duration and create a debug bundle",
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
Example: " netbird debug for 5m",
Args: cobra.ExactArgs(1),
RunE: runForDuration,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: level,
})
if err != nil {
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level set successfully to", args[0])
return nil
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
}
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second)
if restoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
startTime := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
elapsed := time.Since(startTime)
if elapsed >= duration {
return
}
remaining := duration - elapsed
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}

View File

@@ -2,9 +2,8 @@ package cmd
import (
"context"
"time"
"github.com/netbirdio/netbird/util"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"

View File

@@ -32,11 +32,8 @@ const (
preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
)
var (
@@ -64,14 +61,8 @@ var (
serverSSHAllowed bool
interfaceName string
wireguardPort uint16
networkMonitor bool
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
Long: "",
@@ -109,24 +100,15 @@ func init() {
if runtime.GOOS == "windows" {
defaultDaemonAddr = "tcp://127.0.0.1:41731"
}
defaultServiceName := "netbird"
if runtime.GOOS == "windows" {
defaultServiceName = "Netbird"
}
rootCmd.PersistentFlags().StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
rootCmd.PersistentFlags().StringVarP(&managementURL, "management-url", "m", "", fmt.Sprintf("Management Service URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultManagementURL))
rootCmd.PersistentFlags().StringVar(&adminURL, "admin-url", "", fmt.Sprintf("Admin Panel URL [http|https]://[host]:[port] (default \"%s\")", internal.DefaultAdminURL))
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
@@ -134,20 +116,8 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -355,17 +325,3 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}

View File

@@ -1,174 +0,0 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -2,6 +2,8 @@ package cmd
import (
"context"
"runtime"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -22,8 +24,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
}
func newSVCConfig() *service.Config {
name := "netbird"
if runtime.GOOS == "windows" {
name = "Netbird"
}
return &service.Config{
Name: serviceName,
Name: name,
DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
Option: make(service.KeyValue),

View File

@@ -64,10 +64,6 @@ var installCmd = &cobra.Command{
}
}
if runtime.GOOS == "windows" {
svcConfig.Option["OnFailure"] = "restart"
}
ctx, cancel := context.WithCancel(cmd.Context())
s, err := newSVC(newProgram(ctx, cancel), svcConfig)

View File

@@ -24,7 +24,7 @@ var (
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Use: "ssh",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
"You can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}

View File

@@ -6,8 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
@@ -16,7 +14,6 @@ import (
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -37,9 +34,7 @@ type peerStateDetailOutput struct {
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
}
type peersStateOutput struct {
@@ -77,28 +72,19 @@ type iceCandidateType struct {
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
}
var (
@@ -147,9 +133,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed initializing log %v", err)
}
ctx := internal.CtxInitState(cmd.Context())
ctx := internal.CtxInitState(context.Background())
resp, err := getStatus(ctx)
resp, err := getStatus(ctx, cmd)
if err != nil {
return err
}
@@ -182,7 +168,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = parseToYAML(outputInformationHolder)
default:
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false)
}
if err != nil {
@@ -194,7 +180,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -203,7 +189,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -282,13 +268,6 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
@@ -320,19 +299,6 @@ func mapRelays(relays []*proto.RelayState) relayStateOutput {
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
localICE := ""
@@ -385,9 +351,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetRoutes(),
}
peersStateDetail = append(peersStateDetail, peerState)
@@ -437,7 +401,8 @@ func parseToYAML(overview statusOutputOverview) (string, error) {
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
@@ -473,7 +438,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
interfaceIP = "N/A"
}
var relaysString string
var relayAvailableString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
@@ -482,46 +447,15 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
relayAvailableString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
relayAvailableString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
routes := "-"
if len(overview.Routes) > 0 {
sort.Strings(overview.Routes)
routes = strings.Join(overview.Routes, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
@@ -531,41 +465,26 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
relayAvailableString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
routes,
peersCountString,
)
return summary
@@ -573,7 +492,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
summary := parseGeneralSummary(overview, true, true)
return fmt.Sprintf(
"Peers detail:"+
@@ -610,6 +529,15 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
lastStatusUpdate := "-"
if !peerState.LastStatusUpdate.IsZero() {
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
}
lastWireGuardHandshake := "-"
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
@@ -628,12 +556,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
}
}
routes := "-"
if len(peerState.Routes) > 0 {
sort.Strings(peerState.Routes)
routes = strings.Join(peerState.Routes, ", ")
}
peerString := fmt.Sprintf(
"\n %s:\n"+
" NetBird IP: %s\n"+
@@ -647,9 +569,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Latency: %s\n",
" Quantum resistance: %s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
@@ -660,13 +580,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
lastStatusUpdate,
lastWireGuardHandshake,
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
routes,
peerState.Latency.String(),
)
peersString += peerString
@@ -720,144 +638,3 @@ func toIEC(b int64) string {
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = anonymizeRoute(a, route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
overview.Routes[i] = anonymizeRoute(a, route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@@ -3,14 +3,11 @@ package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
@@ -45,10 +42,6 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Routes: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
@@ -65,7 +58,6 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
@@ -95,31 +87,6 @@ var resp = &proto.StatusResponse{
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Routes: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
@@ -149,10 +116,6 @@ var overview = statusOutputOverview{
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
@@ -173,7 +136,6 @@ var overview = statusOutputOverview{
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
@@ -209,31 +171,6 @@ var overview = statusOutputOverview{
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -295,11 +232,7 @@ func TestParsingToJSON(t *testing.T) {
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
]
"quantumResistance":false
},
{
"fqdn": "peer-2.awesome-domain.com",
@@ -320,9 +253,7 @@ func TestParsingToJSON(t *testing.T) {
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null
"quantumResistance":false
}
]
},
@@ -358,33 +289,8 @@ func TestParsingToJSON(t *testing.T) {
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
"quantumResistance":false,
"quantumResistancePermissive":false
}`
// @formatter:on
@@ -418,10 +324,7 @@ func TestParsingToYAML(t *testing.T) {
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
@@ -438,9 +341,7 @@ func TestParsingToYAML(t *testing.T) {
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
cliVersion: development
daemonVersion: 0.14.1
management:
@@ -467,37 +368,15 @@ usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
expectedDetail :=
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
@@ -508,12 +387,10 @@ func TestParsingToDetail(t *testing.T) {
Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Last connection update: %s
Last WireGuard handshake: %s
Last connection update: 2001-01-01 01:01:01
Last WireGuard handshake: 2001-01-01 01:01:02
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
@@ -524,50 +401,41 @@ func TestParsingToDetail(t *testing.T) {
Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Last connection update: %s
Last WireGuard handshake: %s
Last connection update: 2002-02-02 02:02:02
Last WireGuard handshake: 2002-02-02 02:02:03
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
CLI version: development
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
`
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
shortVersion := parseGeneralSummary(overview, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
expectedString :=
`Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`
@@ -581,31 +449,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -7,17 +7,12 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto"
@@ -56,10 +51,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
go func() {
if err := s.Serve(lis); err != nil {
panic(err)
@@ -76,24 +68,22 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -108,7 +98,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
}
func startClientDaemon(
t *testing.T, ctx context.Context, _, configPath string,
t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@@ -7,13 +7,11 @@ import (
"net/netip"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -42,12 +40,6 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -91,12 +83,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
}
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
}
if cmd.Flag(enableRosenpassFlag).Changed {
@@ -123,10 +114,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
@@ -143,10 +130,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@@ -162,12 +145,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
return connectClient.Run()
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
@@ -208,7 +190,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -243,14 +224,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.WireguardPort = &wp
}
if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error
var loginResp *proto.LoginResponse

View File

@@ -1,30 +0,0 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Info("creating an nftables firewall manager")
log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
}
if iface.IsUserspaceBind() {
@@ -85,58 +85,16 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
useIPTABLES := false
var iptablesChains []string
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil && isIptablesClientAvailable(ip) {
major, minor, _ := ip.GetIptablesVersion()
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
if major < 1 || (major == 1 && minor < 8) {
return IPTABLES
}
useIPTABLES = true
iptablesChains, err = ip.ListChains("filter")
if err != nil {
log.Errorf("failed to list iptables chains: %s", err)
useIPTABLES = false
}
}
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
if useIPTABLES {
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil
}
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
@@ -87,12 +87,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil
}
// insertRoutingRule inserts an iptables rule
// insertRoutingRule inserts an iptable rule
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
@@ -101,7 +101,6 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -318,13 +317,6 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -334,33 +326,9 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
func getIptablesRuleType(table string) string {

View File

@@ -51,12 +51,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: "100.100.100.0/24",
Masquerade: true,
}
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
@@ -90,7 +92,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
@@ -101,7 +103,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
@@ -112,7 +114,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -128,7 +130,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -165,25 +167,25 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")

View File

@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair)
return m.router.InsertRoutingRules(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@@ -22,8 +22,6 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
@@ -128,22 +126,6 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -156,28 +138,28 @@ func (r *router) createContainers() error {
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
@@ -195,8 +177,8 @@ func (r *router) AddRoutingRules(pair manager.RouterPair) error {
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -217,7 +199,7 @@ func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPai
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,

View File

@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair)
err = manager.InsertRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()

View File

@@ -64,18 +64,15 @@ func manageFirewallRule(ruleName string, action action, extraArgs ...string) err
if action == addRule {
args = append(args, extraArgs...)
}
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
func isWindowsFirewallReachable() bool {
args := []string{"advfirewall", "show", "allprofiles", "state"}
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output()
@@ -90,23 +87,8 @@ func isWindowsFirewallReachable() bool {
func isFirewallRuleActive(ruleName string) bool {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output()
return err == nil
}
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@@ -5,17 +5,12 @@ import (
"fmt"
"net/url"
"os"
"reflect"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
@@ -35,10 +30,8 @@ const (
DefaultAdminURL = "https://app.netbird.io:443"
)
var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}
var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo"}
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
@@ -53,10 +46,7 @@ type ConfigInput struct {
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
}
// Config Configuration type
@@ -68,7 +58,6 @@ type Config struct {
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor *bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
@@ -99,9 +88,6 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility
DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -111,14 +97,6 @@ func ReadConfig(configPath string) (*Config, error) {
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
}
// initialize through apply() without changes
if changed, err := config.apply(ConfigInput{}); err != nil {
return nil, err
} else if changed {
if err = WriteOutConfig(configPath, config); err != nil {
return nil, err
}
}
return config, nil
}
@@ -171,15 +149,78 @@ func WriteOutConfig(path string, config *Config) error {
// createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
}
if _, err := config.apply(input); err != nil {
wgKey := generateKey()
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}
config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
ServerSSHAllowed: util.False(),
DisableAutoConnect: false,
}
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = defaultManagementURL
if input.ManagementURL != "" {
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = URL
}
config.WgPort = iface.DefaultWgPort
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
}
config.WgIface = iface.WgInterfaceDefault
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
}
if input.PreSharedKey != nil {
config.PreSharedKey = *input.PreSharedKey
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
}
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return nil, err
}
config.AdminURL = defaultAdminURL
if input.AdminURL != "" {
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
}
config.AdminURL = newURL
}
config.IFaceBlackList = defaultInterfaceBlacklist
return config, nil
}
@@ -190,12 +231,97 @@ func update(input ConfigInput) (*Config, error) {
return nil, err
}
updated, err := config.apply(input)
if err != nil {
return nil, err
refresh := false
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
log.Infof("new Management URL provided, updated to %s (old value %s)",
input.ManagementURL, config.ManagementURL)
newURL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = newURL
refresh = true
}
if updated {
if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
input.AdminURL, config.AdminURL)
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
}
config.AdminURL = newURL
refresh = true
}
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
if config.SSHKey == "" {
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}
config.SSHKey = string(pem)
refresh = true
}
if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
refresh = true
}
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
refresh = true
}
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
refresh = true
}
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
config.NATExternalIPs = input.NATExternalIPs
refresh = true
}
if input.CustomDNSAddress != nil {
config.CustomDNSAddress = string(input.CustomDNSAddress)
refresh = true
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
refresh = true
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
refresh = true
}
if input.DisableAutoConnect != nil {
config.DisableAutoConnect = *input.DisableAutoConnect
refresh = true
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
refresh = true
}
if config.ServerSSHAllowed == nil {
config.ServerSSHAllowed = util.True()
refresh = true
}
if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
}
@@ -204,190 +330,6 @@ func update(input ConfigInput) (*Config, error) {
return config, nil
}
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
input.ManagementURL, config.ManagementURL.String())
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return false, err
}
config.ManagementURL = URL
updated = true
} else if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err
}
}
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
input.AdminURL, config.AdminURL.String())
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return updated, err
}
config.AdminURL = newURL
updated = true
}
if config.PrivateKey == "" {
log.Infof("generated new Wireguard key")
config.PrivateKey = generateKey()
updated = true
}
if config.SSHKey == "" {
log.Infof("generated new SSH key")
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return false, err
}
config.SSHKey = string(pem)
updated = true
}
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
log.Infof("updating Wireguard port %d (old value %d)",
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
log.Infof("updating Wireguard interface %#v (old value %#v)",
*input.InterfaceName, config.WgIface)
config.WgIface = *input.InterfaceName
updated = true
} else if config.WgIface == "" {
config.WgIface = iface.WgInterfaceDefault
log.Infof("using default Wireguard interface %s", config.WgIface)
updated = true
}
if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
strings.Join(input.NATExternalIPs, " "),
strings.Join(config.NATExternalIPs, " "))
config.NATExternalIPs = input.NATExternalIPs
updated = true
}
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
updated = true
}
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
config.RosenpassEnabled = *input.RosenpassEnabled
updated = true
}
if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
config.RosenpassPermissive = *input.RosenpassPermissive
updated = true
}
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = input.NetworkMonitor
updated = true
}
if config.NetworkMonitor == nil {
// enable network monitoring by default on windows and darwin clients
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
enabled := true
config.NetworkMonitor = &enabled
updated = true
}
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
config.CustomDNSAddress = string(input.CustomDNSAddress)
updated = true
}
if len(config.IFaceBlackList) == 0 {
log.Infof("filling in interface blacklist with defaults: [ %s ]",
strings.Join(defaultInterfaceBlacklist, " "))
config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...)
updated = true
}
if len(input.ExtraIFaceBlackList) > 0 {
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
log.Infof("adding new entry to interface blacklist: %s", iFace)
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
updated = true
}
}
if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
if *input.DisableAutoConnect {
log.Infof("turning off automatic connection on startup")
} else {
log.Infof("enabling automatic connection on startup")
}
config.DisableAutoConnect = *input.DisableAutoConnect
updated = true
}
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if *input.ServerSSHAllowed {
log.Infof("enabling SSH server")
} else {
log.Infof("disabling SSH server")
}
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
return updated, nil
}
// parseURL parses and validates a service URL
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
@@ -442,6 +384,7 @@ func configFileIsExists(path string) bool {
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version.
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err

View File

@@ -18,6 +18,7 @@ func TestGetConfig(t *testing.T) {
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
if err != nil {
return
}
@@ -85,26 +86,6 @@ func TestGetConfig(t *testing.T) {
assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL)
}
func TestExtraIFaceBlackList(t *testing.T) {
extraIFaceBlackList := []string{"eth1"}
path := filepath.Join(t.TempDir(), "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: path,
ExtraIFaceBlackList: extraIFaceBlackList,
})
if err != nil {
return
}
assert.Contains(t, config.IFaceBlackList, "eth1")
readConf, err := util.ReadJson(path, config)
if err != nil {
return
}
assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1")
}
func TestHiddenPreSharedKey(t *testing.T) {
hidden := "**********"
samplePreSharedKey := "mysecretpresharedkey"
@@ -130,6 +111,7 @@ func TestHiddenPreSharedKey(t *testing.T) {
ConfigPath: cfgFile,
PreSharedKey: tt.preSharedKey,
})
if err != nil {
t.Fatalf("failed to get cfg: %s", err)
}

View File

@@ -4,11 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"runtime"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
@@ -31,45 +27,29 @@ import (
"github.com/netbirdio/netbird/version"
)
type ConnectClient struct {
ctx context.Context
config *Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
}
func NewConnectClient(
// RunClientWithProbes runs the client's main logic with probes attached
func RunClientWithProbes(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
) error {
return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
}
// RunOnAndroid with main logic on mobile system
func (c *ConnectClient) RunOnAndroid(
// RunClientMobile with main logic on mobile system
func RunClientMobile(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
tunAdapter iface.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
@@ -84,43 +64,40 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
}
func (c *ConnectClient) RunOniOS(
func RunClientiOS(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return c.run(mobileDependency, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
}
func (c *ConnectClient) run(
func runClient(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
mobileDependency MobileDependency,
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
) error {
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
@@ -134,7 +111,7 @@ func (c *ConnectClient) run(
Clock: backoff.SystemClock,
}
state := CtxGetState(c.ctx)
state := CtxGetState(ctx)
defer func() {
s, err := state.Status()
if err != nil || s != StatusNeedsLogin {
@@ -143,49 +120,49 @@ func (c *ConnectClient) run(
}()
wrapErr := state.Wrap
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
return wrapErr(err)
}
var mgmTlsEnabled bool
if c.config.ManagementURL.Scheme == "https" {
if config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true
}
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
defer c.statusRecorder.ClientStop()
defer statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
select {
case <-c.ctx.Done():
case <-ctx.Done():
return nil
default:
}
state.Set(StatusConnecting)
engineCtx, cancel := context.WithCancel(c.ctx)
engineCtx, cancel := context.WithCancel(ctx)
defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err)
c.statusRecorder.CleanLocalPeerState()
statusRecorder.MarkManagementDisconnected(state.err)
statusRecorder.CleanLocalPeerState()
cancel()
}()
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
defer func() {
err = mgmClient.Close()
if err != nil {
@@ -203,7 +180,7 @@ func (c *ConnectClient) run(
}
return wrapErr(err)
}
c.statusRecorder.MarkManagementConnected()
statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
@@ -212,18 +189,18 @@ func (c *ConnectClient) run(
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
c.statusRecorder.UpdateLocalPeerState(localPeerState)
statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s",
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
)
c.statusRecorder.UpdateSignalAddress(signalURL)
statusRecorder.UpdateSignalAddress(signalURL)
c.statusRecorder.MarkSignalDisconnected(nil)
statusRecorder.MarkSignalDisconnected(nil)
defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err)
statusRecorder.MarkSignalDisconnected(state.err)
}()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@@ -239,40 +216,35 @@ func (c *ConnectClient) run(
}
}()
signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder)
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
signalClient.SetConnStateListener(signalNotifier)
c.statusRecorder.MarkSignalConnected()
statusRecorder.MarkSignalConnected()
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
}
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engineMutex.Unlock()
err = c.engine.Start()
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
err = engine.Start()
if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
<-engineCtx.Done()
c.statusRecorder.ClientTeardown()
statusRecorder.ClientTeardown()
backOff.Reset()
err = c.engine.Stop()
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
return wrapErr(err)
@@ -287,7 +259,7 @@ func (c *ConnectClient) run(
return nil
}
c.statusRecorder.ClientStart()
statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
@@ -299,20 +271,8 @@ func (c *ConnectClient) run(
return nil
}
func (c *ConnectClient) Engine() *Engine {
var e *Engine
c.engineMutex.Lock()
e = c.engine
c.engineMutex.Unlock()
return e
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
@@ -320,14 +280,12 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
}
if config.PreSharedKey != "" {
@@ -338,15 +296,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
engineConf.PreSharedKey = &preSharedKey
}
port, err := freePort(config.WgPort)
if err != nil {
return nil, err
}
if port != config.WgPort {
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
}
engineConf.WgPort = port
return engineConf, nil
}
@@ -396,20 +345,3 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
notifier, _ := sri.(signal.ConnStateNotifier)
return notifier
}
func freePort(start int) (int, error) {
addr := net.UDPAddr{}
if start == 0 {
start = iface.DefaultWgPort
}
for x := start; x <= 65535; x++ {
addr.Port = x
conn, err := net.ListenUDP("udp", &addr)
if err != nil {
continue
}
conn.Close()
return x, nil
}
return 0, errors.New("no free ports")
}

View File

@@ -1,57 +0,0 @@
package internal
import (
"net"
"testing"
)
func Test_freePort(t *testing.T) {
tests := []struct {
name string
port int
want int
wantErr bool
}{
{
name: "available",
port: 51820,
want: 51820,
wantErr: false,
},
{
name: "notavailable",
port: 51830,
want: 51831,
wantErr: false,
},
{
name: "noports",
port: 65535,
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil {
t.Errorf("freePort error = %v", err)
}
c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
if err != nil {
t.Errorf("freePort error = %v", err)
}
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
if (err != nil) != tt.wantErr {
t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("freePort() = %v, want %v", got, tt.want)
}
})
c1.Close()
c2.Close()
}
}

View File

@@ -1,6 +0,0 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@@ -1,8 +0,0 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns
@@ -47,20 +47,24 @@ func (f *fileConfigurator) supportCustomPort() bool {
}
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
backupFileExist := f.isBackupFileExist()
backupFileExist := false
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
if err == nil {
backupFileExist = true
}
if !config.RouteAll {
if backupFileExist {
f.repair.stopWatchFileChanges()
err := f.restore()
err = f.restore()
if err != nil {
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
}
}
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
}
if !backupFileExist {
err := f.backup()
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
}
@@ -180,11 +184,6 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return nil
}
func (f *fileConfigurator) isBackupFileExist() bool {
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
return err == nil
}
func restoreResolvConfFile() error {
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns
@@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
if checkStub() {
return systemdManager, nil
} else {
@@ -116,10 +116,16 @@ func getOSDNSManagerType() (osManagerType, error) {
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil
}
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
}
return resolvConfManager, nil
}
}

View File

@@ -1,63 +0,0 @@
package dns
import (
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
)
type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{}
mutex sync.RWMutex
}
func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}),
}
}
func (h *hostsDNSHolder) set(list []string) {
h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{})
for _, dns := range list {
dnsAddr, err := h.normalizeAddress(dns)
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
}
h.mutex.Unlock()
}
func (h *hostsDNSHolder) get() map[string]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList
h.mutex.RUnlock()
return l
}
//nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool {
h.mutex.RLock()
defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream]
return ok
}
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@@ -31,8 +31,6 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
response := d.lookupRecord(r)
if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -4,8 +4,6 @@ import (
"context"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/miekg/dns"
@@ -13,7 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -55,14 +52,13 @@ type DefaultServer struct {
currentConfig HostDNSConfig
// permanent related properties
permanent bool
hostsDNSHolder *hostsDNSHolder
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
// make sense on mobile only
searchDomainNotifier *notifier
iosDnsManager IosDnsManager
statusRecorder *peer.Status
}
type handlerWithStop interface {
@@ -77,12 +73,7 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -99,22 +90,15 @@ func NewDefaultServer(
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
return newDefaultServer(ctx, wgInterface, dnsService), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(
ctx context.Context,
wgInterface WGIface,
hostsDnsList []string,
config nbdns.Config,
listener listener.NetworkChangeListener,
statusRecorder *peer.Status,
) *DefaultServer {
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string, config nbdns.Config, listener listener.NetworkChangeListener) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
@@ -124,18 +108,13 @@ func NewDefaultServerPermanentUpstream(
}
// NewDefaultServerIos returns a new dns server. It optimized for ios
func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
@@ -145,9 +124,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
statusRecorder: statusRecorder,
hostsDNSHolder: newHostsDNSHolder(),
wgInterface: wgInterface,
}
return defaultServer
@@ -203,8 +180,10 @@ func (s *DefaultServer) Stop() {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList)
s.hostsDnsListLock.Lock()
defer s.hostsDnsListLock.Unlock()
s.hostsDnsList = hostsDnsList
_, ok := s.dnsMuxMap[nbdns.RootZone]
if ok {
log.Debugf("on new host DNS config but skip to apply it")
@@ -277,15 +256,9 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)
go func(mux handlerWithStop) {
defer wg.Done()
mux.probeAvailability()
}(mux)
mux.probeAvailability()
}
wg.Wait()
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
@@ -326,8 +299,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
}
s.updateNSGroupStates(update.NameServerGroups)
return nil
}
@@ -367,14 +338,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue
}
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
s.hostsDNSHolder,
)
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
}
@@ -452,7 +416,9 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
_, found := muxUpdateMap[key]
if !found {
if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop()
} else {
existingHandler.stop()
@@ -494,14 +460,14 @@ func getNSHostPort(ns nbdns.NameServer) string {
func (s *DefaultServer) upstreamCallbacks(
nsGroup *nbdns.NameServerGroup,
handler dns.Handler,
) (deactivate func(error), reactivate func()) {
) (deactivate func(), reactivate func()) {
var removeIndex map[string]int
deactivate = func(err error) {
deactivate = func() {
s.mux.Lock()
defer s.mux.Unlock()
l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("Temporarily deactivating nameservers group due to timeout")
l.Info("temporary deactivate nameservers group due timeout")
removeIndex = make(map[string]int)
for _, domain := range nsGroup.Domains {
@@ -510,7 +476,6 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false
s.service.DeregisterMux(nbdns.RootZone)
}
for i, item := range s.currentConfig.Domains {
@@ -520,17 +485,9 @@ func (s *DefaultServer) upstreamCallbacks(
removeIndex[item.Domain] = i
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
}
s.updateNSState(nsGroup, err, false)
}
reactivate = func() {
s.mux.Lock()
@@ -549,79 +506,36 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler)
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
s.updateNSState(nsGroup, nil, true)
}
return
}
func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
s.hostsDNSHolder,
)
handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList {
a, err := netip.ParseAddr(ua)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
continue
}
handler.upstreamServers = make([]string, 0)
for k := range s.hostsDNSHolder.get() {
handler.upstreamServers = append(handler.upstreamServers, k)
ipString := ua
if !a.Is4() {
ipString = fmt.Sprintf("[%s]", ua)
}
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
}
handler.deactivate = func(error) {}
handler.deactivate = func() {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
}
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState
for _, group := range groups {
var servers []string
for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
}
state := peer.NSGroupState{
ID: generateGroupKey(group),
Servers: servers,
Domains: group.Domains,
// The probe will determine the state, default enabled
Enabled: true,
Error: nil,
}
states = append(states, state)
}
s.statusRecorder.UpdateDNSStates(states)
}
func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) {
states := s.statusRecorder.GetDNSStates()
id := generateGroupKey(nsGroup)
for i, state := range states {
if state.ID == id {
states[i].Enabled = enabled
states[i].Error = err
break
}
}
s.statusRecorder.UpdateDNSStates(states)
}
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string
for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
}
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns

View File

@@ -15,7 +15,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
@@ -39,10 +38,6 @@ func (w *mocWGIface) Address() iface.WGAddress {
}
}
func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
@@ -265,7 +260,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Fatal(err)
}
@@ -279,7 +274,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Fatal(err)
}
@@ -343,7 +338,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
@@ -380,7 +375,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -475,7 +470,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
if err != nil {
t.Fatalf("%v", err)
}
@@ -546,7 +541,6 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
{false, "domain2", false},
},
},
statusRecorder: &peer.Status{},
}
var domainsUpdate string
@@ -569,7 +563,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
},
}, nil)
deactivate(nil)
deactivate()
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
@@ -607,7 +601,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
var dnsList []string
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -631,7 +625,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -723,7 +717,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -754,11 +748,6 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("9.9.9.9"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,
@@ -801,7 +790,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err

View File

@@ -1,20 +0,0 @@
package dns
import (
"errors"
"fmt"
)
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}
func isSystemdResolvedRunning() bool {
return false
}
func isSystemdResolveConfMode() bool {
return false
}

View File

@@ -242,25 +242,3 @@ func getSystemdDbusProperty(property string, store any) error {
return v.Store(store)
}
func isSystemdResolvedRunning() bool {
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
}
func isSystemdResolveConfMode() bool {
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return false
}
var value string
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
return false
}
if value == systemdDbusResolvConfModeForeign {
return true
}
return false
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build !android
package dns
@@ -14,6 +14,11 @@ import (
log "github.com/sirupsen/logrus"
)
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@@ -5,16 +5,14 @@ import (
"errors"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
@@ -47,13 +45,12 @@ type upstreamResolverBase struct {
reactivatePeriod time.Duration
upstreamTimeout time.Duration
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
deactivate func()
reactivate func()
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
ctx, cancel := context.WithCancel(parentCTX)
return &upstreamResolverBase{
ctx: ctx,
@@ -61,7 +58,6 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder,
}
}
@@ -72,17 +68,9 @@ func (u *upstreamResolverBase) stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
var err error
defer func() {
u.checkUpstreamFails(err)
}()
defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():
@@ -93,6 +81,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
for _, upstream := range u.upstreamServers {
var rm *dns.Msg
var t time.Duration
var err error
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
@@ -143,7 +132,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
func (u *upstreamResolverBase) checkUpstreamFails() {
u.mutex.Lock()
defer u.mutex.Unlock()
@@ -157,7 +146,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
default:
}
u.disable(err)
u.disable()
}
// probeAvailability tests all upstream servers simultaneously and
@@ -176,16 +165,13 @@ func (u *upstreamResolverBase) probeAvailability() {
var mu sync.Mutex
var wg sync.WaitGroup
var errors *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream
wg.Add(1)
go func() {
defer wg.Done()
err := u.testNameserver(upstream)
if err != nil {
errors = multierror.Append(errors, err)
if err := u.testNameserver(upstream); err != nil {
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
@@ -200,7 +186,7 @@ func (u *upstreamResolverBase) probeAvailability() {
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errors.ErrorOrNil())
u.disable()
}
}
@@ -259,15 +245,18 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
func (u *upstreamResolverBase) disable() {
if u.disabled {
return
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
// todo test the deactivation logic, it seems to affect the client
if runtime.GOOS != "ios" {
log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true
go u.waitUntilResponse()
}
}
func (u *upstreamResolverBase) testNameserver(server string) error {

View File

@@ -1,84 +0,0 @@
package dns
import (
"context"
"net"
"syscall"
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net"
)
type upstreamResolver struct {
*upstreamResolverBase
hostsDNSHolder *hostsDNSHolder
}
// newUpstreamResolver in Android we need to distinguish the DNS servers to available through VPN or outside of VPN
// In case if the assigned DNS address is available only in the protected network then the resolver will time out at the
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,
}
upstreamResolverBase.upstreamClient = c
return c, nil
}
// exchange in case of Android if the upstream is a local resolver then we do not need to mark the socket as protected.
// In other case the DNS resolvation goes through the VPN, so we need to force to use the
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
if u.isLocalResolver(upstream) {
return u.exchangeWithoutVPN(ctx, upstream, r)
} else {
return u.exchangeWithinVPN(ctx, upstream, r)
}
}
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
timeout := upstreamTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
dialTimeout := timeout
nbDialer := nbnet.NewDialer()
dialer := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return nbDialer.Control(network, address, c)
},
Timeout: dialTimeout,
}
upstreamExchangeClient := &dns.Client{
Dialer: dialer,
}
return upstreamExchangeClient.Exchange(r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) {
return true
}
return false
}

View File

@@ -1,38 +0,0 @@
//go:build !android && !ios
package dns
import (
"context"
"net"
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolver struct {
*upstreamResolverBase
}
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}

View File

@@ -11,8 +11,6 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolverIOS struct {
@@ -22,15 +20,8 @@ type upstreamResolverIOS struct {
iIndex int
}
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip net.IP,
net *net.IPNet,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
index, err := getInterfaceIndex(interfaceName)
if err != nil {

View File

@@ -0,0 +1,29 @@
//go:build !ios
package dns
import (
"context"
"net"
"time"
"github.com/miekg/dns"
)
type upstreamResolverNonIOS struct {
*upstreamResolverBase
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
nonIOS := &upstreamResolverNonIOS{
upstreamResolverBase: upstreamResolverBase,
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}

View File

@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
@@ -131,7 +131,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
}
failed := false
resolver.deactivate = func(error) {
resolver.deactivate = func() {
failed = true
}

View File

@@ -2,17 +2,12 @@
package dns
import (
"net"
"github.com/netbirdio/netbird/iface"
)
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper

View File

@@ -2,15 +2,12 @@ package internal
import (
"context"
"errors"
"fmt"
"maps"
"math/rand"
"net"
"net/netip"
"reflect"
"runtime"
"slices"
"strings"
"sync"
"time"
@@ -24,26 +21,21 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -68,9 +60,6 @@ type EngineConfig struct {
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
WgPrivateKey wgtypes.Key
// NetworkMonitor is a flag to enable network monitoring
NetworkMonitor bool
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
IFaceBlackList []string
DisableIPv6Discovery bool
@@ -94,8 +83,6 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
DNSRouteInterval time.Duration
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -106,10 +93,6 @@ type Engine struct {
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@@ -124,16 +107,10 @@ type Engine struct {
// TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context
clientCancel context.CancelFunc
ctx context.Context
cancel context.CancelFunc
ctx context.Context
wgInterface *iface.WGIface
wgProxyFactory *wgproxy.Factory
@@ -142,8 +119,6 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
networkMonitor *networkmonitor.NetworkMonitor
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
@@ -159,11 +134,6 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
wgConnWorker sync.WaitGroup
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
}
// Peer is an instance of the Connection Peer
@@ -174,18 +144,17 @@ type Peer struct {
// NewEngine creates a new Connection Engine
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
ctx context.Context,
cancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
return NewEngineWithProbes(
clientCtx,
clientCancel,
ctx,
cancel,
signalClient,
mgmClient,
config,
@@ -195,14 +164,13 @@ func NewEngine(
nil,
nil,
nil,
checks,
)
}
// NewEngineWithProbes creates a new Connection Engine with probes attached
func NewEngineWithProbes(
clientCtx context.Context,
clientCancel context.CancelFunc,
ctx context.Context,
cancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
config *EngineConfig,
@@ -212,12 +180,10 @@ func NewEngineWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
ctx: ctx,
cancel: cancel,
signal: signalClient,
mgmClient: mgmClient,
peerConns: make(map[string]*peer.Conn),
@@ -229,11 +195,11 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
checks: checks,
}
}
@@ -241,31 +207,16 @@ func (e *Engine) Stop() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.cancel != nil {
e.cancel()
}
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
}
log.Info("Network monitor: stopped")
err := e.removeAllPeers()
if err != nil {
return err
}
e.clientRoutesMu.Lock()
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
// Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil
}
@@ -277,21 +228,13 @@ func (e *Engine) Start() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
wgIface, err := e.newWgIface()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
return fmt.Errorf("new wg interface: %w", err)
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error())
return err
}
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {
@@ -301,37 +244,29 @@ func (e *Engine) Start() error {
}
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err)
return err
}
err := e.rpManager.Run()
if err != nil {
return fmt.Errorf("run rosenpass manager: %w", err)
return err
}
}
initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
return err
}
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate()
if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
return err
}
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
@@ -343,7 +278,7 @@ func (e *Engine) Start() error {
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
return err
}
}
@@ -351,7 +286,7 @@ func (e *Engine) Start() error {
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err)
return err
}
if e.firewall != nil {
@@ -361,16 +296,13 @@ func (e *Engine) Start() error {
err = e.dnsServer.Initialize()
if err != nil {
e.close()
return fmt.Errorf("initialize dns server: %w", err)
return err
}
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveProbeEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
return nil
}
@@ -542,10 +474,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
@@ -553,27 +481,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
}
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
@@ -589,8 +497,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on Windows is not supported")
return nil
}
// start SSH server if it wasn't running
@@ -663,19 +571,12 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err := e.mgmClient.Sync(e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
e.cancel()
return
}
log.Debugf("stopped receiving updates from Management Service")
@@ -737,20 +638,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -792,6 +679,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update routes, err: %v", err)
}
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
@@ -803,40 +698,30 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
e.dnsServer.ProbeAvailability()
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
e.networkSerial = serial
return nil
}
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
var prefix netip.Prefix
if len(protoRoute.Domains) == 0 {
var err error
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
continue
}
}
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
ID: protoRoute.ID,
Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID),
NetID: protoRoute.NetID,
NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute,
}
routes = append(routes, convertedRoute)
}
@@ -896,7 +781,6 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)
@@ -920,39 +804,27 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
if _, ok := e.peerConns[peerKey]; !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
return err
}
e.peerConns[peerKey] = conn
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
e.wgConnWorker.Add(1)
go e.connWorker(conn, peerKey)
}
return nil
}
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
defer e.wgConnWorker.Done()
for {
// randomize starting time a bit
min := 500
max := 2000
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond
select {
case <-e.ctx.Done():
return
case <-time.After(duration):
}
time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
// if peer has been removed -> give up
if !e.peerExists(peerKey) {
@@ -970,12 +842,11 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
e.syncMsgMux.Unlock()
err := conn.Open(e.ctx)
err := conn.Open()
if err != nil {
log.Debugf("connection to peer %s failed: %v", peerKey, err)
var connectionClosedError *peer.ConnectionClosedError
switch {
case errors.As(err, &connectionClosedError):
switch err.(type) {
case *peer.ConnectionClosedError:
// conn has been forced to close, so we exit the loop
return
default:
@@ -1039,6 +910,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
}
@@ -1085,7 +957,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
func (e *Engine) receiveSignalEvents() {
go func() {
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
err := e.signal.Receive(func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1101,6 +973,8 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1123,6 +997,8 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1145,8 +1021,7 @@ func (e *Engine) receiveSignalEvents() {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err
}
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
conn.OnRemoteCandidate(candidate)
case sProto.Body_MODE:
}
@@ -1156,7 +1031,7 @@ func (e *Engine) receiveSignalEvents() {
// happens if signal is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
e.cancel()
return
}
}()
@@ -1217,20 +1092,13 @@ func (e *Engine) parseNATExternalIPMappings() []string {
}
func (e *Engine) close() {
if e.wgProxyFactory != nil {
if err := e.wgProxyFactory.Free(); err != nil {
log.Errorf("failed closing ebpf proxy: %s", err)
}
if err := e.wgProxyFactory.Free(); err != nil {
log.Errorf("failed closing ebpf proxy: %s", err)
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
@@ -1247,6 +1115,10 @@ func (e *Engine) close() {
}
}
if e.routeManager != nil {
e.routeManager.Stop()
}
if e.firewall != nil {
err := e.firewall.Reset()
if err != nil {
@@ -1260,8 +1132,7 @@ func (e *Engine) close() {
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx)
netMap, err := e.mgmClient.GetNetworkMap(info)
netMap, err := e.mgmClient.GetNetworkMap()
if err != nil {
return nil, nil, err
}
@@ -1290,7 +1161,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default:
}
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
}
func (e *Engine) wgInterfaceCreate() (err error) {
@@ -1317,21 +1188,14 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx,
e.wgInterface,
e.mobileDep.HostDNSAddresses,
*dnsConfig,
e.mobileDep.NetworkChangeListener,
e.statusRecorder,
)
dnsServer := dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener)
go e.mobileDep.DnsReadyListener.OnReady()
return routes, dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager)
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil {
return nil, nil, err
}
@@ -1339,31 +1203,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
}
}
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
routes[id.NetID()] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
@@ -1464,72 +1303,3 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
debounceTimer.Stop()
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@@ -17,13 +17,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -58,9 +55,9 @@ var (
)
func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
t.Skip("skipping TestEngine_SSH")
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH on Windows")
}
key, err := wgtypes.GeneratePrivateKey()
@@ -73,12 +70,12 @@ func TestEngine_SSH(t *testing.T) {
defer cancel()
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -174,7 +171,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err)
}
// time.Sleep(250 * time.Millisecond)
//time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@@ -212,16 +209,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
@@ -230,7 +227,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
type testCase struct {
name string
@@ -394,7 +390,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -409,8 +405,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -568,13 +563,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
@@ -582,10 +576,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr
return testCase.inputErr
},
}
@@ -602,8 +596,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
})
}
}
@@ -738,19 +732,17 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
return nil, nil, nil
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
@@ -811,13 +803,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
sigServer, signalAddr, err := startSignal()
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, dir)
mgmtServer, mgmtAddr, err := startManagement(dir)
if err != nil {
t.Fatal(err)
return
@@ -1009,14 +1001,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
func startSignal() (*grpc.Server, string, error) {
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
@@ -1024,9 +1012,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
go func() {
if err = s.Serve(lis); err != nil {
@@ -1037,9 +1023,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
t.Helper()
func startManagement(dataDir string) (*grpc.Server, string, error) {
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
@@ -1056,25 +1040,22 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
store, err := server.NewStoreFromJson(config.Datadir, nil)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "", eventStore, nil, false)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -68,7 +68,7 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
}
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
return err

View File

@@ -1,21 +0,0 @@
package networkmonitor
import (
"context"
"errors"
"sync"
)
var ErrStopped = errors.New("monitor has been stopped")
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex
}
// New creates a new network monitor.
func New() *NetworkMonitor {
return &NetworkMonitor{}
}

View File

@@ -1,95 +0,0 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
package networkmonitor
import (
"context"
"fmt"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()
for {
select {
case <-ctx.Done():
return ErrStopped
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
continue
}
if n < unix.SizeofRtMsghdr {
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing routing message: %v", err)
continue
}
if !route.Dst.Addr().IsUnspecified() {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback()
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback()
}
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
}

View File

@@ -1,82 +0,0 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

View File

@@ -1,57 +0,0 @@
//go:build !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available")
}
done := make(chan struct{})
defer close(done)
routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err)
}
log.Info("Network monitor: started")
for {
select {
case <-ctx.Done():
return ErrStopped
// handle route changes
case route := <-routeChan:
// default route and main table
if route.Dst != nil || route.Table != syscall.RT_TABLE_MAIN {
continue
}
switch route.Type {
// triggered on added/replaced routes
case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
}
}
}
}
}

View File

@@ -1,12 +0,0 @@
//go:build ios || android
package networkmonitor
import "context"
func (nw *NetworkMonitor) Start(context.Context, func()) error {
return nil
}
func (nw *NetworkMonitor) Stop() {
}

View File

@@ -1,245 +0,0 @@
package networkmonitor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
const (
unreachable = 0
incomplete = 1
probe = 2
delay = 3
stale = 4
reachable = 5
permanent = 6
tbd = 7
)
const interval = 10 * time.Second
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
var neighborv4, neighborv6 *systemops.Neighbor
{
initialNeighbors, err := getNeighbors()
if err != nil {
return fmt.Errorf("get neighbors: %w", err)
}
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
}
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ErrStopped
case <-ticker.C:
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
go callback()
return nil
}
}
}
}
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed(
nexthopv4 systemops.Nexthop,
neighborv4 *systemops.Neighbor,
nexthopv6 systemops.Nexthop,
neighborv6 *systemops.Neighbor,
) bool {
neighbors, err := getNeighbors()
if err != nil {
log.Errorf("network monitor: error fetching current neighbors: %v", err)
return false
}
if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
return true
}
routes, err := getRoutes()
if err != nil {
log.Errorf("network monitor: error fetching current routes: %v", err)
return false
}
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
return true
}
return false
}
// routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
if !nexthop.IP.IsValid() {
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
if !foundMatchingRoute {
logRouteChange(nexthop.IP, intf)
return true
}
return false
}
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
}
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
if neighbor == nil {
return false
}
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop.IP]; ok {
if n.State == unreachable || n.State == incomplete {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
log.Infof(
"network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
neighbor.IPAddress,
neighbor.LinkLayerAddress,
neighbor.InterfaceAlias,
neighbor.InterfaceIndex,
n.InterfaceAlias,
n.InterfaceIndex,
stateFromInt(n.State),
)
return true
}
} else {
log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
return true
}
return false
}
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
entries, err := systemops.GetNeighbors()
if err != nil {
return nil, fmt.Errorf("get neighbors: %w", err)
}
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
for _, entry := range entries {
neighbours[entry.IPAddress] = entry
}
return neighbours, nil
}
func getRoutes() ([]systemops.Route, error) {
entries, err := systemops.GetRoutes()
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
return entries, nil
}
func stateFromInt(state uint8) string {
switch state {
case unreachable:
return "unreachable"
case incomplete:
return "incomplete"
case probe:
return "probe"
case delay:
return "delay"
case stale:
return "stale"
case reachable:
return "reachable"
case permanent:
return "permanent"
case tbd:
return "tbd"
default:
return "unknown"
}
}
func compareIntf(a, b *net.Interface) int {
if a == nil && b == nil {
return 0
}
if a == nil {
return -1
}
if b == nil {
return 1
}
return a.Index - b.Index
}

View File

@@ -18,17 +18,14 @@ import (
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
defaultWgKeepAlive = 25 * time.Second
)
@@ -68,6 +65,9 @@ type ConnConfig struct {
NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
// RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
@@ -127,13 +127,23 @@ type Conn struct {
wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy
remoteModeCh chan ModeMessage
meta meta
adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover
sentExtraSrflx bool
}
connID nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// meta holds meta information about a connection
type meta struct {
protoSupport signal.FeaturesSupport
}
// ModeMessage represents a connection mode chosen by the peer
type ModeMessage struct {
// Direct indicates that it decided to use a direct connection
Direct bool
}
// GetConf returns the connection config
@@ -162,6 +172,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1),
wgProxyFactory: wgProxyFactory,
adapter: adapter,
iFaceDiscover: iFaceDiscover,
@@ -182,22 +193,20 @@ func (conn *Conn) reCreateAgent() error {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
UDPMuxSrflx: conn.config.UDPMuxSrflx,
NAT1To1IPs: conn.config.NATExternalIPs,
Net: transportNet,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
}
if conn.config.DisableIPv6Discovery {
@@ -205,6 +214,7 @@ func (conn *Conn) reCreateAgent() error {
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil {
return err
}
@@ -224,17 +234,6 @@ func (conn *Conn) reCreateAgent() error {
return err
}
err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency())
if err != nil {
log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err)
return
}
})
if err != nil {
return fmt.Errorf("failed setting binding response callback: %w", err)
}
return nil
}
@@ -252,7 +251,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
// Open opens connection to the remote peer starting ICE candidate gathering process.
// Blocks until connection has been closed or connection timeout.
// ConnStatus will be set accordingly
func (conn *Conn) Open(ctx context.Context) error {
func (conn *Conn) Open() error {
log.Debugf("trying to connect to peer %s", conn.config.Key)
peerState := State{
@@ -260,7 +259,6 @@ func (conn *Conn) Open(ctx context.Context) error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -312,7 +310,7 @@ func (conn *Conn) Open(ctx context.Context) error {
// at this point we received offer/answer and we are ready to gather candidates
conn.mu.Lock()
conn.status = StatusConnecting
conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx)
conn.ctx, conn.notifyDisconnected = context.WithCancel(context.Background())
defer conn.notifyDisconnected()
conn.mu.Unlock()
@@ -320,7 +318,6 @@ func (conn *Conn) Open(ctx context.Context) error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -329,7 +326,7 @@ func (conn *Conn) Open(ctx context.Context) error {
err = conn.agent.GatherCandidates()
if err != nil {
return fmt.Errorf("gather candidates: %v", err)
return err
}
// will block until connection succeeded
@@ -346,12 +343,11 @@ func (conn *Conn) Open(ctx context.Context) error {
return err
}
// dynamically set remote WireGuard port if other side specified a different one from the default one
// dynamically set remote WireGuard port is other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
@@ -376,14 +372,6 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
conn.mu.Lock()
@@ -397,7 +385,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
var endpoint net.Addr
if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx)
conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err
@@ -409,23 +397,13 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.connID = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
log.Errorf("Before add peer hook failed: %v", err)
}
}
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
if err := conn.wgProxy.CloseConn(); err != nil {
log.Warnf("Failed to close turn connection: %v", err)
}
_ = conn.wgProxy.CloseConn()
}
return nil, fmt.Errorf("update peer: %w", err)
return nil, err
}
conn.status = StatusConnected
@@ -441,10 +419,9 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
@@ -460,10 +437,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
return nil, err
}
if runtime.GOOS == "ios" {
runtime.GC()
}
if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
}
@@ -515,15 +488,6 @@ func (conn *Conn) cleanup() error {
// todo: is it problem if we try to remove a peer what is never existed?
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if conn.connID != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connID); err != nil {
log.Errorf("After remove peer hook failed: %v", err)
}
}
}
conn.connID = ""
if conn.notifyDisconnected != nil {
conn.notifyDisconnected()
conn.notifyDisconnected = nil
@@ -539,7 +503,6 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -594,39 +557,40 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
// nil means candidate gathering has been ended
if candidate == nil {
return
if candidate != nil {
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := conn.signalCandidate(candidate)
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
relatedAdd := candidate.RelatedAddress()
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
err = conn.signalCandidate(extraSrflx)
if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
return
}
conn.sentExtraSrflx = true
}
}()
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := conn.signalCandidate(candidate)
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
}
}()
if !conn.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
conn.sentExtraSrflx = true
go func() {
err = conn.signalCandidate(extraSrflx)
if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
}
}()
}
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
@@ -708,7 +672,7 @@ func (conn *Conn) Close() error {
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
// engine adds a new Conn for 4 and 5
// therefore peer 4 has 2 Conn objects
log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key)
log.Warnf("connection has been already closed or attempted closing not started coonection %s", conn.config.Key)
return NewConnectionAlreadyClosed(conn.config.Key)
}
}
@@ -751,7 +715,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
go func() {
conn.mu.Lock()
@@ -773,21 +737,8 @@ func (conn *Conn) GetKey() string {
return conn.config.Key
}
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
// RegisterProtoSupportMeta register supported proto message in the connection metadata
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
protoSupport := signal.ParseFeaturesSupported(support)
conn.meta.protoSupport = protoSupport
}

View File

@@ -1,7 +1,6 @@
package peer
import (
"context"
"sync"
"testing"
"time"
@@ -36,7 +35,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -51,7 +50,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -88,7 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -124,7 +123,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -154,7 +153,7 @@ func TestConn_Status(t *testing.T) {
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()

View File

@@ -10,10 +10,9 @@ import (
)
const (
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
)
func iceKeepAlive() time.Duration {
@@ -22,7 +21,7 @@ func iceKeepAlive() time.Duration {
return iceKeepAliveDefault
}
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv)
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
@@ -38,7 +37,7 @@ func iceDisconnectedTimeout() time.Duration {
return iceDisconnectedTimeoutDefault
}
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
@@ -48,22 +47,6 @@ func iceDisconnectedTimeout() time.Duration {
return time.Duration(disconnectedTimeoutSec) * time.Second
}
func iceRelayAcceptanceMinWait() time.Duration {
iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec)
if iceRelayAcceptanceMinWaitEnv == "" {
return iceRelayAcceptanceMinWaitDefault
}
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
return iceRelayAcceptanceMinWaitDefault
}
return time.Duration(disconnectedTimeoutSec) * time.Second
}
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"

View File

@@ -2,22 +2,18 @@ package peer
import (
"errors"
"net/netip"
"sync"
"time"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
)
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
@@ -32,40 +28,7 @@ type State struct {
LastWireguardHandshake time.Time
BytesTx int64
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
defer s.Mux.Unlock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
defer s.Mux.Unlock()
s.routes = routes
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
defer s.Mux.Unlock()
delete(s.routes, network)
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -74,7 +37,6 @@ type LocalPeerState struct {
PubKey string
KernelInterface bool
FQDN string
Routes map[string]struct{}
}
// SignalState contains the latest state of a signal connection
@@ -97,16 +59,6 @@ type RosenpassState struct {
Permissive bool
}
// NSGroupState represents the status of a DNS server group, including associated domains,
// whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct {
ID string
Servers []string
Domains []string
Enabled bool
Error error
}
// FullStatus contains the full state held by the Status instance
type FullStatus struct {
Peers []State
@@ -115,28 +67,25 @@ type FullStatus struct {
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
}
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -147,12 +96,11 @@ type Status struct {
// NewRecorder returns a new Status instance
func NewRecorder(mgmAddress string) *Status {
return &Status{
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
}
}
@@ -180,7 +128,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
@@ -193,7 +140,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey]
if !ok {
return State{}, iface.ErrPeerNotFound
return State{}, errors.New("peer not found")
}
return state, nil
}
@@ -227,10 +174,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@@ -335,13 +278,6 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
return ch
}
// GetLocalPeerState returns the local peer state
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer
}
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock()
@@ -428,24 +364,6 @@ func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
d.relayStates = relayResults
}
func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.mux.Lock()
defer d.mux.Unlock()
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
}
func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{
d.rosenpassEnabled,
@@ -461,22 +379,6 @@ func (d *Status) GetManagementState() ManagementState {
}
}
func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
if latency <= 0 {
return nil
}
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[pubKey]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.Latency = latency
d.peers[pubKey] = peerState
return nil
}
// IsLoginRequired determines if a peer's login has expired.
func (d *Status) IsLoginRequired() bool {
d.mux.Lock()
@@ -490,6 +392,7 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true
}
return false
}
@@ -506,16 +409,6 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
@@ -527,7 +420,6 @@ func (d *Status) GetFullStatus() FullStatus {
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
for _, status := range d.peers {

View File

@@ -3,7 +3,6 @@ package peer
import (
"errors"
"testing"
"sync"
"github.com/stretchr/testify/assert"
)
@@ -43,7 +42,6 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -64,7 +62,6 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -83,7 +80,6 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -108,7 +104,6 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -10,9 +10,6 @@ import (
"github.com/pion/stun/v2"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net"
)
// ProbeResult holds the info about the result of a relay probe request
@@ -30,15 +27,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
}()
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
}
client, err := stun.DialURI(uri, &stun.DialConfig{
Net: net,
})
client, err := stun.DialURI(uri, &stun.DialConfig{})
if err != nil {
probeErr = fmt.Errorf("dial: %w", err)
return
@@ -96,13 +85,14 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
switch uri.Proto {
case stun.ProtoTypeUDP:
var err error
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
conn, err = net.ListenPacket("udp", "")
if err != nil {
probeErr = fmt.Errorf("listen: %w", err)
return
}
case stun.ProtoTypeTCP:
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
if err != nil {
probeErr = fmt.Errorf("dial: %w", err)
return
@@ -119,18 +109,12 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
}()
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
}
cfg := &turn.ClientConfig{
STUNServerAddr: turnServerAddr,
TURNServerAddr: turnServerAddr,
Conn: conn,
Username: uri.Username,
Password: uri.Password,
Net: net,
}
client, err := turn.NewClient(cfg)
if err != nil {
@@ -170,7 +154,7 @@ func ProbeAll(
var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
wg.Add(1)

View File

@@ -3,25 +3,21 @@ package routemanager
import (
"context"
"fmt"
"time"
"net/netip"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
const minRangeBits = 7
type routerPeerStatus struct {
connected bool
relayed bool
direct bool
latency time.Duration
}
type routesUpdate struct {
@@ -29,48 +25,38 @@ type routesUpdate struct {
routes []*route.Route
}
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
stop context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
routes map[route.ID]*route.Route
routes map[string]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{}
currentChosen *route.Route
handler RouteHandler
chosenRoute *route.Route
network netip.Prefix
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
ctx: ctx,
cancel: cancel,
stop: cancel,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
routes: make(map[route.ID]*route.Route),
routes: make(map[string]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
network: network,
}
return client
}
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[route.ID]routerPeerStatus)
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
routePeerStatuses := make(map[string]routerPeerStatus)
for _, r := range c.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
if err != nil {
@@ -81,37 +67,22 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
direct: peerStatus.Direct,
latency: peerStatus.Latency,
}
}
return routePeerStatuses
}
// getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections.
//
// It follows these prioritization rules:
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Latency: Routes with lower latency are prioritized.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
chosen := route.ID("")
chosenScore := float64(0)
currScore := float64(0)
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
chosen := ""
chosenScore := 0
currID := route.ID("")
if c.currentChosen != nil {
currID = c.currentChosen.ID
currID := ""
if c.chosenRoute != nil {
currID = c.chosenRoute.ID
}
for _, r := range c.routes {
tempScore := float64(0)
tempScore := 0
peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected {
continue
@@ -119,18 +90,9 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = float64(metricDiff) * 10
tempScore = metricDiff * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
}
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
tempScore++
}
@@ -139,7 +101,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
chosen = r.ID
chosenScore = tempScore
}
@@ -148,31 +110,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
chosen = r.ID
chosenScore = tempScore
}
if r.ID == currID {
currScore = tempScore
}
}
switch {
case chosen == "":
if chosen == "" {
var peers []string
for _, r := range c.routes {
peers = append(peers, r.Peer)
}
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
return currID
}
var p string
if rt := c.routes[chosen]; rt != nil {
p = rt.Peer
}
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
} else if chosen != currID {
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
}
return chosen
@@ -206,101 +155,83 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil {
return err
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil
}
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.currentChosen == nil {
return nil
}
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
routerPeerStatuses := c.getRouterPeerStatuses()
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if newChosenID == "" {
if err := c.removeRouteFromPeerAndSystem(); err != nil {
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
if c.chosenRoute != nil {
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
if err != nil {
return err
}
c.currentChosen = nil
return nil
}
// If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.IsEqual(c.routes[newChosenID]) {
return nil
}
if c.currentChosen == nil {
// If they were not previously assigned to another peer, add routes to the system first
if err := c.handler.AddRoute(c.ctx); err != nil {
return fmt.Errorf("add route: %w", err)
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
if err != nil {
return fmt.Errorf("couldn't remove route %s from system, err: %v",
c.network, err)
}
}
c.currentChosen = c.routes[newChosenID]
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
c.addStateRoute()
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
var err error
routerPeerStatuses := c.getRouterPeerStatuses()
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
if chosen == "" {
err = c.removeRouteFromPeerAndSystem()
if err != nil {
return err
}
c.chosenRoute = nil
return nil
}
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil
}
}
if c.chosenRoute != nil {
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
if err != nil {
return err
}
} else {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
if err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
}
}
c.chosenRoute = c.routes[chosen]
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
return nil
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
@@ -310,7 +241,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
}
func (c *clientNetwork) handleUpdate(update routesUpdate) {
updateMap := make(map[route.ID]*route.Route)
updateMap := make(map[string]*route.Route)
for _, r := range update.routes {
updateMap[r.ID] = r
@@ -333,23 +264,24 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
for {
select {
case <-c.ctx.Done():
log.Debugf("Stopping watcher for network [%v]", c.handler)
if err := c.removeRouteFromPeerAndSystem(); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
log.Debugf("stopping watcher for network %s", c.network)
err := c.removeRouteFromPeerAndSystem()
if err != nil {
log.Error(err)
}
return
case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
log.Error(err)
}
case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
log.Warnf("received a routes update with smaller serial number, ignoring it")
continue
}
log.Debugf("Received a new client network route update for [%v]", c.handler)
log.Debugf("received a new client network route update for %s", c.network)
c.handleUpdate(update)
@@ -357,17 +289,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
log.Error(err)
}
c.startPeersStatusChangeWatcher()
}
}
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}

View File

@@ -3,9 +3,7 @@ package routemanager
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
@@ -13,90 +11,90 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct {
name string
statuses map[route.ID]routerPeerStatus
expectedRouteID route.ID
currentRoute route.ID
existingRoutes map[route.ID]*route.Route
statuses map[string]routerPeerStatus
expectedRouteID string
currentRoute *route.Route
existingRoutes map[string]*route.Route
}{
{
name: "one route",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "one connected routes with relayed and direct",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "one connected routes with relayed and no direct",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "no connected peers",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: false,
relayed: false,
direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "",
},
{
name: "multiple connected peers with different metrics",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
@@ -108,7 +106,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: 9000,
@@ -120,12 +118,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "multiple connected peers with one relayed",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
@@ -137,7 +135,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
@@ -149,12 +147,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "multiple connected peers with one direct",
statuses: map[route.ID]routerPeerStatus{
statuses: map[string]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
@@ -166,7 +164,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
existingRoutes: map[string]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
@@ -178,172 +176,18 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
Peer: "peer2",
},
},
currentRoute: "",
currentRoute: nil,
expectedRouteID: "route1",
},
{
name: "multiple connected peers with different latencies",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
latency: 300 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "should ignore routes with latency 0",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "",
expectedRouteID: "route2",
},
{
name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 15 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 200 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route2",
},
{
name: "current chosen route doesn't exist anymore",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
direct: true,
latency: 10 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "routeDoesntExistAnymore",
expectedRouteID: "route2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
ID: "routeDoesntExistAnymore",
}
if tc.currentRoute != "" {
currentRoute = tc.existingRoutes[tc.currentRoute]
}
// create new clientNetwork
client := &clientNetwork{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes,
chosenRoute: tc.currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@@ -1,378 +0,0 @@
package dynamic
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
type domainMap map[domain.Domain][]netip.Prefix
type resolveResult struct {
domain domain.Domain
prefix netip.Prefix
err error
}
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
interval time.Duration
dynamicDomains domainMap
mu sync.Mutex
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
}
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
}
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
}
func (r *Route) AddRoute(ctx context.Context) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
ctx, r.cancel = context.WithCancel(ctx)
go r.startResolver(ctx)
return nil
}
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
func (r *Route) RemoveRoute() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
var merr *multierror.Error
for domain, prefixes := range r.dynamicDomains {
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
}
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
r.statusRecorder.DeleteResolvedDomainsStates(domain)
}
r.dynamicDomains = domainMap{}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) AddAllowedIPs(peerKey string) error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for domain, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
}
r.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) RemoveAllowedIPs() error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for _, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
}
r.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
interval := r.interval
if interval < minInterval {
interval = minInterval
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for {
select {
case <-ctx.Done():
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
}
}
}
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
return nil
}
func (r *Route) resolveDomains() (domainMap, error) {
results := make(chan resolveResult)
go r.resolve(results)
resolved := domainMap{}
var merr *multierror.Error
for result := range results {
if result.err != nil {
merr = multierror.Append(merr, result.err)
} else {
resolved[result.domain] = append(resolved[result.domain], result.prefix)
}
}
return resolved, nberrors.FormatErrorOrNil(merr)
}
func (r *Route) resolve(results chan resolveResult) {
var wg sync.WaitGroup
for _, d := range r.route.Domains {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
return
}
results <- resolveResult{domain: domain, prefix: prefix}
}
}(d)
}
wg.Wait()
close(results)
}
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
r.mu.Lock()
defer r.mu.Unlock()
if ctx.Err() != nil {
return ctx.Err()
}
var merr *multierror.Error
for domain, newPrefixes := range newDomains {
oldPrefixes := r.dynamicDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
addedPrefixes, err := r.addRoutes(domain, toAdd)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(addedPrefixes) > 0 {
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
}
removedPrefixes, err := r.removeRoutes(toRemove)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(removedPrefixes) > 0 {
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
}
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
var addedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
if r.currentPeerKey != "" {
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
addedPrefixes = append(addedPrefixes, prefix)
}
return addedPrefixes, merr.ErrorOrNil()
}
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
if r.route.KeepRoute {
return nil, nil
}
var removedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
if r.currentPeerKey != "" {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
removedPrefixes = append(removedPrefixes, prefix)
}
return removedPrefixes, merr.ErrorOrNil()
}
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
return fmt.Errorf(addAllowedIP, prefix, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
prefixSet := make(map[netip.Prefix]struct{})
for _, prefix := range oldPrefixes {
prefixSet[prefix] = struct{}{}
}
for _, prefix := range removedPrefixes {
delete(prefixSet, prefix)
}
for _, prefix := range addedPrefixes {
prefixSet[prefix] = struct{}{}
}
var combinedPrefixes []netip.Prefix
for prefix := range prefixSet {
combinedPrefixes = append(combinedPrefixes, prefix)
}
return combinedPrefixes
}

View File

@@ -2,36 +2,22 @@ package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
@@ -40,71 +26,29 @@ type Manager interface {
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
}
func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface *iface.WGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
}
dm.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ any) (any, error) {
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ any) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
dm.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
}
return nil
},
)
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
@@ -112,31 +56,9 @@ func NewManager(
return dm
}
// Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := m.sysOps.CleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
mgmtAddress := m.statusRecorder.GetManagementState().URL
signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}
log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
if err != nil {
return err
}
@@ -149,53 +71,32 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if m.routeRefCounter != nil {
if err := m.routeRefCounter.Flush(); err != nil {
log.Errorf("Error flushing route ref counter: %v", err)
}
}
if m.allowedIPsRefCounter != nil {
if err := m.allowedIPsRefCounter.Flush(); err != nil {
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
}
}
if !nbnet.CustomRoutingDisabled() {
if err := m.sysOps.CleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
}
m.ctx = nil
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return nil, nil, m.ctx.Err()
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return nil, nil, fmt.Errorf("update routes: %w", err)
return err
}
}
return newServerRoutesMap, newClientRoutesIDMap, nil
return nil
}
}
@@ -206,62 +107,24 @@ func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeL
// InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.getInitialRouteRanges()
return m.notifier.initialRouteRanges()
}
// GetRouteSelector returns the route selector
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
return m.clientNetworks
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks)
m.notifier.onNewRoutes(networks)
m.stopObsoleteClients(networks)
for id, routes := range networks {
if _, found := m.clientNetworks[id]; found {
// don't touch existing client network watchers
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
client.cancel()
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
// removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks)
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
@@ -273,15 +136,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
}
}
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
newClientRoutesIDMap := make(route.HAMap)
newServerRoutesMap := make(map[route.ID]*route.Route)
ownNetworkIDs := make(map[route.HAUniqueID]bool)
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
haID := newRoute.GetHAUniqueID()
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
@@ -292,12 +155,16 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
}
for _, newRoute := range newRoutes {
haID := newRoute.GetHAUniqueID()
if !ownNetworkIDs[haID] {
if !isRouteSupported(newRoute) {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < minRangeBits {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
@@ -305,44 +172,10 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap))
_, crMap := m.classifiesRoutes(initialRoutes)
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}
func isRouteSupported(route *route.Route) bool {
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
return true
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management
if route.Network.Bits() <= vars.MinRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), route.Network)
return false
}
return true
}
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
func resolveURLsToIPs(urls []string) []net.IP {
var ips []net.IP
for _, rawurl := range urls {
u, err := url.Parse(rawurl)
if err != nil {
log.Errorf("Failed to parse url %s: %v", rawurl, err)
continue
}
ipAddrs, err := net.LookupIP(u.Hostname())
if err != nil {
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
continue
}
ips = append(ips, ipAddrs...)
}
return ips
}

View File

@@ -28,14 +28,13 @@ const remotePeerKey2 = "remote1"
func TestManagerUpdateRoutes(t *testing.T) {
testCases := []struct {
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
clientNetworkWatchersExpectedAllowed int
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int
clientNetworkWatchersExpected int
}{
{
name: "Should create 2 client networks",
@@ -201,9 +200,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedAllowed: 1,
inputSerial: 1,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove 1 Client Route",
@@ -407,7 +405,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
@@ -416,11 +414,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
_, _, err = routeManager.Init()
require.NoError(t, err, "should init route manager")
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
defer routeManager.Stop()
if testCase.removeSrvRouter {
@@ -428,18 +422,14 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
}
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
if testCase.clientNetworkWatchersExpectedAllowed != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
sr := routeManager.serverRouter.(*defaultServerRouter)

View File

@@ -6,22 +6,14 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
)
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func()
}
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
StopFunc func()
}
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
@@ -30,25 +22,11 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
}
func (m *MockManager) TriggerSelection(networks route.HAMap) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)
}
}
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
if m.GetRouteSelectorFunc != nil {
return m.GetRouteSelectorFunc()
}
return nil
return fmt.Errorf("method UpdateRoutes is not implemented")
}
// Start mock implementation of Start from Manager interface

View File

@@ -1,7 +1,6 @@
package routemanager
import (
"runtime"
"sort"
"strings"
"sync"
@@ -11,8 +10,8 @@ import (
)
type notifier struct {
initialRouteRanges []string
routeRanges []string
initialRouteRangers []string
routeRangers []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -34,10 +33,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
nets = append(nets, r.Network.String())
}
sort.Strings(nets)
n.initialRouteRanges = nets
n.initialRouteRangers = nets
}
func (n *notifier) onNewRoutes(idMap route.HAMap) {
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
@@ -46,18 +45,11 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
if !n.hasDiff(n.initialRouteRangers, newNets) {
return
}
n.routeRanges = newNets
n.routeRangers = newNets
n.notify()
}
@@ -70,7 +62,7 @@ func (n *notifier) notify() {
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
l.OnNetworkChanged(strings.Join(n.routeRangers, ","))
}(n.listener)
}
@@ -86,20 +78,6 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false
}
func (n *notifier) getInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
func addIPv6RangeIfNeeded(inputRanges []string) []string {
ranges := inputRanges
for _, r := range inputRanges {
// we are intentionally adding the ipv6 default range in case of ipv4 default range
// to ensure that all traffic is managed by the tunnel interface on android
if r == "0.0.0.0/0" {
ranges = append(ranges, "::/0")
break
}
}
return ranges
func (n *notifier) initialRouteRanges() []string {
return n.initialRouteRangers
}

View File

@@ -1,155 +0,0 @@
package refcounter
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
var ErrIgnore = errors.New("ignore")
type Ref[O any] struct {
Count int
Out O
}
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
type Counter[I, O any] struct {
// refCountMap keeps track of the reference Ref for prefixes
refCountMap map[netip.Prefix]Ref[O]
refCountMu sync.Mutex
// idMap keeps track of the prefixes associated with an ID for removal
idMap map[string][]netip.Prefix
idMu sync.Mutex
add AddFunc[I, O]
remove RemoveFunc[I, O]
}
// New creates a new Counter instance
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
return &Counter[I, O]{
refCountMap: map[netip.Prefix]Ref[O]{},
idMap: map[string][]netip.Prefix{},
add: add,
remove: remove,
}
}
// Increment increments the reference count for the given prefix.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[prefix]
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
// Call AddFunc only if it's a new prefix
if ref.Count == 0 {
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
out, err := rm.add(prefix, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
}
ref.Out = out
}
ref.Count++
rm.refCountMap[prefix] = ref
return ref, nil
}
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.Increment(prefix, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
rm.idMap[id] = append(rm.idMap[id], prefix)
return ref, nil
}
// Decrement decrements the reference count for the given prefix.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[prefix]
if !ok {
log.Tracef("No reference found for prefix %s", prefix)
return ref, nil
}
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
if ref.Count == 1 {
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
if err := rm.remove(prefix, ref.Out); err != nil {
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
}
delete(rm.refCountMap, prefix)
} else {
ref.Count--
rm.refCountMap[prefix] = ref
}
return ref, nil
}
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, prefix := range rm.idMap[id] {
if _, err := rm.Decrement(prefix); err != nil {
merr = multierror.Append(merr, err)
}
}
delete(rm.idMap, id)
return nberrors.FormatErrorOrNil(merr)
}
// Flush removes all references and calls RemoveFunc for each prefix.
func (rm *Counter[I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for prefix := range rm.refCountMap {
log.Tracef("Removing for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.remove(prefix, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]Ref[O]{}
rm.idMap = map[string][]netip.Prefix{}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,7 +0,0 @@
package refcounter
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
type RouteRefCounter = Counter[any, any]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
type AllowedIPsRefCounter = Counter[string, string]

View File

@@ -3,7 +3,7 @@ package routemanager
import "github.com/netbirdio/netbird/route"
type serverRouter interface {
updateRoutes(map[route.ID]*route.Route) error
updateRoutes(map[string]*route.Route) error
removeFromServerNetwork(*route.Route) error
cleanUp()
}

View File

@@ -7,10 +7,9 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -4,40 +4,35 @@ package routemanager
import (
"context"
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type defaultServerRouter struct {
mux sync.Mutex
ctx context.Context
routes map[route.ID]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
statusRecorder *peer.Status
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
firewall: firewall,
wgInterface: wgInterface,
statusRecorder: statusRecorder,
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: firewall,
wgInterface: wgInterface,
}, nil
}
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
serverRoutesToRemove := make([]route.ID, 0)
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
for routeID := range m.routes {
update, found := routesMap[routeID]
@@ -50,7 +45,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.routes, routeID)
@@ -64,14 +59,14 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.routes[id] = newRoute
}
if len(m.routes) > 0 {
err := systemops.EnableIPForwarding()
err := enableIPForwarding()
if err != nil {
return err
}
@@ -83,28 +78,16 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("Not removing from server network because context is done")
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
return err
}
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
@@ -112,37 +95,16 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("Not adding to server network because context is done")
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
return err
}
err = m.firewall.InsertRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
}
@@ -151,45 +113,19 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
routerPair, err := routeToRouterPair(r)
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err)
continue
log.Warnf("failed to remove clean up route: %s", r.ID)
}
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
}
state := m.statusRecorder.GetLocalPeerState()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
destination := route.Network.Masked().String()
if route.IsDynamic() {
// TODO: add ipv6
destination = "0.0.0.0/0"
}
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
parsed := netip.MustParsePrefix(source).Masked()
return firewall.RouterPair{
ID: string(route.ID),
Source: source.String(),
Destination: destination,
ID: route.ID,
Source: parsed.String(),
Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade,
}, nil
}
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -1,57 +0,0 @@
package static
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}
// Route route methods
func (r *Route) String() string {
return r.route.Network.String()
}
func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
return err
}
func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network)
return err
}
func (r *Route) AddAllowedIPs(peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
r.route.Network,
ref.Out,
)
}
return nil
}
func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
return err
}

View File

@@ -1,103 +0,0 @@
// go:build !android
package sysctl
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/iface"
)
const (
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := Set(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = Set(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := Set(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, nberrors.FormatErrorOrNil(result)
}
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
// Cleanup resets sysctl settings to their original values.
func Cleanup(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := Set(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -1,18 +0,0 @@
//go:build darwin || dragonfly || netbsd || openbsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
return false
}

View File

@@ -1,19 +0,0 @@
//go:build: freebsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
return false
}

View File

@@ -1,27 +0,0 @@
package systemops
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
IP netip.Addr
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}
}

View File

@@ -1,160 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"syscall"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
)
type Route struct {
Dst netip.Prefix
Gw netip.Addr
Interface *net.Interface
}
func getRoutesFromTable() ([]netip.Prefix, error) {
tab, err := retryFetchRIB()
if err != nil {
return nil, fmt.Errorf("fetch RIB: %v", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
var prefixList []netip.Prefix
for _, msg := range msgs {
m := msg.(*route.RouteMessage)
if m.Version < 3 || m.Version > 5 {
return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
}
if m.Type != syscall.RTM_GET {
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
}
if filterRoutesByFlags(m.Flags) {
continue
}
route, err := MsgToRoute(m)
if err != nil {
log.Warnf("Failed to parse route message: %v", err)
continue
}
if route.Dst.IsValid() {
prefixList = append(prefixList, route.Dst)
}
}
return prefixList, nil
}
func retryFetchRIB() ([]byte, error) {
var out []byte
operation := func() error {
var err error
out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
if errors.Is(err, syscall.ENOMEM) {
log.Debug("~etrying fetchRIB due to 'cannot allocate memory' error")
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return nil, fmt.Errorf("failed to fetch routing information: %w", err)
}
return out, nil
}
func toNetIP(a route.Addr) netip.Addr {
switch t := a.(type) {
case *route.Inet4Addr:
return netip.AddrFrom4(t.IP)
case *route.Inet6Addr:
ip := netip.AddrFrom16(t.IP)
if t.ZoneID != 0 {
ip = ip.WithZone(strconv.Itoa(t.ZoneID))
}
return ip
default:
return netip.Addr{}
}
}
// ones returns the number of leading ones in the mask.
func ones(a route.Addr) (int, error) {
switch t := a.(type) {
case *route.Inet4Addr:
mask, _ := net.IPMask(t.IP[:]).Size()
return mask, nil
case *route.Inet6Addr:
mask, _ := net.IPMask(t.IP[:]).Size()
return mask, nil
default:
return 0, fmt.Errorf("unexpected address type: %T", a)
}
}
// MsgToRoute converts a route message to a Route.
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
addr := toNetIP(dstIP)
var nexthopAddr netip.Addr
var nexthopIntf *net.Interface
switch t := nexthop.(type) {
case *route.Inet4Addr, *route.Inet6Addr:
nexthopAddr = toNetIP(t)
case *route.LinkAddr:
nexthopIntf = &net.Interface{
Index: t.Index,
Name: t.Name,
}
default:
return nil, fmt.Errorf("unexpected next hop type: %T", t)
}
var prefix netip.Prefix
if dstMask == nil {
if addr.Is4() {
prefix = netip.PrefixFrom(addr, 32)
} else {
prefix = netip.PrefixFrom(addr, 128)
}
} else {
bits, err := ones(dstMask)
if err != nil {
return nil, fmt.Errorf("failed to parse mask: %v", dstMask)
}
prefix = netip.PrefixFrom(addr, bits)
}
return &Route{
Dst: prefix,
Gw: nexthopAddr,
Interface: nexthopIntf,
}, nil
}

View File

@@ -1,188 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/route"
)
var expectedVPNint = "utun100"
var expectedExternalInt = "lo0"
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
destination: "10.10.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return "lo0"
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

Some files were not shown because too many files have changed in this diff Show More