mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 14:44:34 -04:00
Compare commits
79 Commits
v0.26.7
...
deploy/pee
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c44187900 | ||
|
|
0dc050b354 | ||
|
|
69eb6f17c6 | ||
|
|
ccbe2e0100 | ||
|
|
5b042eee74 | ||
|
|
125c418d85 | ||
|
|
9f607029c0 | ||
|
|
4df1d556c8 | ||
|
|
acaf315152 | ||
|
|
81b6f7dc56 | ||
|
|
cb9a8e9fa4 | ||
|
|
f68d9ce042 | ||
|
|
1f182411c6 | ||
|
|
68e971cb82 | ||
|
|
f4a464fb69 | ||
|
|
7aff0e828b | ||
|
|
90ebf0017b | ||
|
|
bf8c6b95de | ||
|
|
8415064758 | ||
|
|
809f6cdbc7 | ||
|
|
e435e39158 | ||
|
|
fd26e989e3 | ||
|
|
7474876d6a | ||
|
|
2095e35217 | ||
|
|
02537e5536 | ||
|
|
f99477d3db | ||
|
|
3085383729 | ||
|
|
07bdf977d0 | ||
|
|
4a147006d2 | ||
|
|
4424162bce | ||
|
|
480192028e | ||
|
|
54b045d9ca | ||
|
|
71c6437bab | ||
|
|
027c0e5c63 | ||
|
|
fbb1181b64 | ||
|
|
e50ed290d9 | ||
|
|
8f8cfcbc20 | ||
|
|
7b254cb966 | ||
|
|
8f3a0f2c38 | ||
|
|
1f33e2e003 | ||
|
|
1e6addaa65 | ||
|
|
f51dc13f8c | ||
|
|
3477108ce7 | ||
|
|
012e624296 | ||
|
|
4c5e987e02 | ||
|
|
a80c8b0176 | ||
|
|
9e01155d2e | ||
|
|
3c3111ad01 | ||
|
|
b74078fd95 | ||
|
|
77488ad11a | ||
|
|
e3b76448f3 | ||
|
|
e0de86d6c9 | ||
|
|
5204d07811 | ||
|
|
5ea24ba56e | ||
|
|
d30cf8706a | ||
|
|
15a2feb723 | ||
|
|
91b2f9fc51 | ||
|
|
76702c8a09 | ||
|
|
061f673a4f | ||
|
|
9505805313 | ||
|
|
704c67dec8 | ||
|
|
3ed2f08f3c | ||
|
|
4c83408f27 | ||
|
|
90bd39c740 | ||
|
|
dd0cf41147 | ||
|
|
22b2caffc6 | ||
|
|
c1f66d1354 | ||
|
|
ac0fe6025b | ||
|
|
c28657710a | ||
|
|
3875c29f6b | ||
|
|
9f32ccd453 | ||
|
|
1d1d057e7d | ||
|
|
3461b1bb90 | ||
|
|
3d2a2377c6 | ||
|
|
25f5f26527 | ||
|
|
bb0d5c5baf | ||
|
|
7938295190 | ||
|
|
9af532fe71 | ||
|
|
23a1473797 |
3
.github/workflows/golang-test-darwin.yml
vendored
3
.github/workflows/golang-test-darwin.yml
vendored
@@ -32,6 +32,9 @@ jobs:
|
||||
restore-keys: |
|
||||
macos-go-
|
||||
|
||||
- name: Install libpcap
|
||||
run: brew install libpcap
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
6
.github/workflows/golangci-lint.yml
vendored
6
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta
|
||||
ignore_words_list: erro,clienta,hastable,
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
@@ -33,6 +33,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
@@ -56,10 +56,10 @@ jobs:
|
||||
with:
|
||||
go-version: "1.21.x"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS netbird lib
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
|
||||
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
13
README.md
13
README.md
@@ -40,11 +40,12 @@
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open-Source Network Security in a Single Platform
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
### Key features
|
||||
@@ -77,7 +78,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
@@ -94,9 +95,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
@@ -120,7 +121,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution).
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
FROM alpine:3.18.5
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/go/bin/netbird","up"]
|
||||
COPY netbird /go/bin/netbird
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
212
client/anonymize/anonymize.go
Normal file
212
client/anonymize/anonymize.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package anonymize
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Anonymizer struct {
|
||||
ipAnonymizer map[netip.Addr]netip.Addr
|
||||
domainAnonymizer map[string]string
|
||||
currentAnonIPv4 netip.Addr
|
||||
currentAnonIPv6 netip.Addr
|
||||
startAnonIPv4 netip.Addr
|
||||
startAnonIPv6 netip.Addr
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 192.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
}
|
||||
|
||||
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
return &Anonymizer{
|
||||
ipAnonymizer: map[netip.Addr]netip.Addr{},
|
||||
domainAnonymizer: map[string]string{},
|
||||
currentAnonIPv4: startIPv4,
|
||||
currentAnonIPv6: startIPv6,
|
||||
startAnonIPv4: startIPv4,
|
||||
startAnonIPv6: startIPv6,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
if ip.IsLoopback() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsInterfaceLocalMulticast() ||
|
||||
ip.IsPrivate() ||
|
||||
ip.IsUnspecified() ||
|
||||
ip.IsMulticast() ||
|
||||
isWellKnown(ip) ||
|
||||
a.isInAnonymizedRange(ip) {
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
if _, ok := a.ipAnonymizer[ip]; !ok {
|
||||
if ip.Is4() {
|
||||
a.ipAnonymizer[ip] = a.currentAnonIPv4
|
||||
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
|
||||
} else {
|
||||
a.ipAnonymizer[ip] = a.currentAnonIPv6
|
||||
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
|
||||
}
|
||||
}
|
||||
return a.ipAnonymizer[ip]
|
||||
}
|
||||
|
||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||
return true
|
||||
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return ip
|
||||
}
|
||||
|
||||
return a.AnonymizeIP(addr).String()
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeDomain(domain string) string {
|
||||
if strings.HasSuffix(domain, "netbird.io") ||
|
||||
strings.HasSuffix(domain, "netbird.selfhosted") ||
|
||||
strings.HasSuffix(domain, "netbird.cloud") ||
|
||||
strings.HasSuffix(domain, "netbird.stage") ||
|
||||
strings.HasSuffix(domain, ".domain") {
|
||||
return domain
|
||||
}
|
||||
|
||||
parts := strings.Split(domain, ".")
|
||||
if len(parts) < 2 {
|
||||
return domain
|
||||
}
|
||||
|
||||
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||
|
||||
anonymized, ok := a.domainAnonymizer[baseDomain]
|
||||
if !ok {
|
||||
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
|
||||
a.domainAnonymizer[baseDomain] = anonymizedBase
|
||||
anonymized = anonymizedBase
|
||||
}
|
||||
|
||||
return strings.Replace(domain, baseDomain, anonymized, 1)
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return uri
|
||||
}
|
||||
|
||||
var anonymizedHost string
|
||||
if u.Opaque != "" {
|
||||
host, port, err := net.SplitHostPort(u.Opaque)
|
||||
if err == nil {
|
||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||
} else {
|
||||
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
||||
}
|
||||
u.Opaque = anonymizedHost
|
||||
} else if u.Host != "" {
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err == nil {
|
||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
||||
} else {
|
||||
anonymizedHost = a.AnonymizeDomain(u.Host)
|
||||
}
|
||||
u.Host = anonymizedHost
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeString(str string) string {
|
||||
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
|
||||
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
|
||||
|
||||
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
|
||||
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
|
||||
|
||||
for domain, anonDomain := range a.domainAnonymizer {
|
||||
str = strings.ReplaceAll(str, domain, anonDomain)
|
||||
}
|
||||
|
||||
str = a.AnonymizeSchemeURI(str)
|
||||
str = a.AnonymizeDNSLogLine(str)
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
|
||||
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
|
||||
|
||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||
}
|
||||
|
||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
||||
domainRegex := regexp.MustCompile(domainPattern)
|
||||
|
||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||
parts := strings.Split(match, `"`)
|
||||
if len(parts) >= 2 {
|
||||
domain := parts[1]
|
||||
if strings.HasSuffix(domain, ".domain") {
|
||||
return match
|
||||
}
|
||||
randomDomain := generateRandomString(10) + ".domain"
|
||||
return strings.Replace(match, domain, randomDomain, 1)
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
|
||||
func isWellKnown(addr netip.Addr) bool {
|
||||
wellKnown := []string{
|
||||
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
|
||||
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
|
||||
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
|
||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||
}
|
||||
|
||||
if slices.Contains(wellKnown, addr.String()) {
|
||||
return true
|
||||
}
|
||||
|
||||
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
|
||||
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
|
||||
|
||||
return cgnatRange.Contains(addr)
|
||||
}
|
||||
|
||||
func generateRandomString(length int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result[i] = letters[num.Int64()]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
223
client/anonymize/anonymize_test.go
Normal file
223
client/anonymize/anonymize_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package anonymize_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
)
|
||||
|
||||
func TestAnonymizeIP(t *testing.T) {
|
||||
startIPv4 := netip.MustParseAddr("198.51.100.0")
|
||||
startIPv6 := netip.MustParseAddr("100::")
|
||||
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expect string
|
||||
}{
|
||||
{"Well known", "8.8.8.8", "8.8.8.8"},
|
||||
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
|
||||
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
|
||||
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
|
||||
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Second Public IPv6", "a::b", "100::1"},
|
||||
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
|
||||
{"Private IPv6", "fe80::1", "fe80::1"},
|
||||
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := netip.MustParseAddr(tc.ip)
|
||||
anonymizedIP := anonymizer.AnonymizeIP(ip)
|
||||
if anonymizedIP.String() != tc.expect {
|
||||
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
||||
|
||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
||||
require.NotEqual(t, testLog, result)
|
||||
assert.NotContains(t, result, "example.com")
|
||||
}
|
||||
|
||||
func TestAnonymizeDomain(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
expectPattern string
|
||||
shouldAnonymize bool
|
||||
}{
|
||||
{
|
||||
"General Domain",
|
||||
"example.com",
|
||||
`^anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Subdomain",
|
||||
"sub.example.com",
|
||||
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Protected Domain",
|
||||
"netbird.io",
|
||||
`^netbird\.io$`,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeDomain(tc.domain)
|
||||
if tc.shouldAnonymize {
|
||||
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
|
||||
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
|
||||
} else {
|
||||
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeURI(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
regex string
|
||||
}{
|
||||
{
|
||||
"HTTP URI with Port",
|
||||
"http://example.com:80/path",
|
||||
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
|
||||
},
|
||||
{
|
||||
"HTTP URI without Port",
|
||||
"http://example.com/path",
|
||||
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
|
||||
},
|
||||
{
|
||||
"Opaque URI with Port",
|
||||
"stun:example.com:80?transport=udp",
|
||||
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
|
||||
},
|
||||
{
|
||||
"Opaque URI without Port",
|
||||
"stun:example.com?transport=udp",
|
||||
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeURI(tc.uri)
|
||||
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
|
||||
require.NotContains(t, result, "example.com", "Original domain should not be present")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeSchemeURI(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
||||
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeSchemeURI(tc.input)
|
||||
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
|
||||
require.NotContains(t, result, "example.com", "Original domain should not be present")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizString_MemorizedDomain(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
domain := "example.com"
|
||||
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
|
||||
|
||||
sampleString := "This is a test string including the domain example.com which should be anonymized."
|
||||
|
||||
firstPassResult := anonymizer.AnonymizeString(sampleString)
|
||||
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
|
||||
|
||||
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
|
||||
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
|
||||
|
||||
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
|
||||
}
|
||||
|
||||
func TestAnonymizeString_DoubleURI(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
domain := "example.com"
|
||||
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
|
||||
|
||||
sampleString := "Check out our site at https://example.com for more info."
|
||||
|
||||
firstPassResult := anonymizer.AnonymizeString(sampleString)
|
||||
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
|
||||
|
||||
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
|
||||
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
|
||||
|
||||
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
|
||||
}
|
||||
|
||||
func TestAnonymizeString_IPAddresses(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 Address",
|
||||
input: "Error occurred at IP 122.138.1.1",
|
||||
expect: "Error occurred at IP 198.51.100.0",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address",
|
||||
input: "Access attempted from 2001:db8::ff00:42",
|
||||
expect: "Access attempted from 100::",
|
||||
},
|
||||
{
|
||||
name: "IPv6 Address with Port",
|
||||
input: "Access attempted from [2001:db8::ff00:42]:8080",
|
||||
expect: "Access attempted from [100::]:8080",
|
||||
},
|
||||
{
|
||||
name: "Both IPv4 and IPv6",
|
||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeString(tc.input)
|
||||
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
|
||||
})
|
||||
}
|
||||
}
|
||||
248
client/cmd/debug.go
Normal file
248
client/cmd/debug.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
|
||||
}
|
||||
|
||||
var debugBundleCmd = &cobra.Command{
|
||||
Use: "bundle",
|
||||
Example: " netbird debug bundle",
|
||||
Short: "Create a debug bundle",
|
||||
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
|
||||
RunE: debugBundle,
|
||||
}
|
||||
|
||||
var logCmd = &cobra.Command{
|
||||
Use: "log",
|
||||
Short: "Manage logging for the Netbird daemon",
|
||||
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
|
||||
}
|
||||
|
||||
var logLevelCmd = &cobra.Command{
|
||||
Use: "level <level>",
|
||||
Short: "Set the logging level for this session",
|
||||
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
|
||||
Available log levels are:
|
||||
panic: for panic level, highest level of severity
|
||||
fatal: for fatal level errors that cause the program to exit
|
||||
error: for error conditions
|
||||
warn: for warning conditions
|
||||
info: for informational messages
|
||||
debug: for debug-level messages
|
||||
trace: for trace-level messages, which include more fine-grained information than debug`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: setLogLevel,
|
||||
}
|
||||
|
||||
var forCmd = &cobra.Command{
|
||||
Use: "for <time>",
|
||||
Short: "Run debug logs for a specified duration and create a debug bundle",
|
||||
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
|
||||
Example: " netbird debug for 5m",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runForDuration,
|
||||
}
|
||||
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := parseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
|
||||
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
|
||||
Level: level,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Log level set successfully to", args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) proto.LogLevel {
|
||||
switch strings.ToLower(level) {
|
||||
case "panic":
|
||||
return proto.LogLevel_PANIC
|
||||
case "fatal":
|
||||
return proto.LogLevel_FATAL
|
||||
case "error":
|
||||
return proto.LogLevel_ERROR
|
||||
case "warn":
|
||||
return proto.LogLevel_WARN
|
||||
case "info":
|
||||
return proto.LogLevel_INFO
|
||||
case "debug":
|
||||
return proto.LogLevel_DEBUG
|
||||
case "trace":
|
||||
return proto.LogLevel_TRACE
|
||||
default:
|
||||
return proto.LogLevel_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
duration, err := time.ParseDuration(args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration format: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
|
||||
Level: proto.LogLevel_TRACE,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Log level set to trace.")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
// TODO reset log level
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
elapsed := time.Since(startTime)
|
||||
if elapsed >= duration {
|
||||
return
|
||||
}
|
||||
remaining := duration - elapsed
|
||||
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
d = d.Round(time.Second)
|
||||
h := d / time.Hour
|
||||
d %= time.Hour
|
||||
m := d / time.Minute
|
||||
d %= time.Minute
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
@@ -65,6 +65,7 @@ var (
|
||||
serviceName string
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
@@ -119,6 +120,8 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
rootCmd.AddCommand(upCmd)
|
||||
rootCmd.AddCommand(downCmd)
|
||||
@@ -126,8 +129,20 @@ func init() {
|
||||
rootCmd.AddCommand(loginCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(routesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
|
||||
routesCmd.AddCommand(routesListCmd)
|
||||
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
debugCmd.AddCommand(debugBundleCmd)
|
||||
debugCmd.AddCommand(logCmd)
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
@@ -335,3 +350,14 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
131
client/cmd/route.go
Normal file
131
client/cmd/route.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var appendFlag bool
|
||||
|
||||
var routesCmd = &cobra.Command{
|
||||
Use: "routes",
|
||||
Short: "Manage network routes",
|
||||
Long: `Commands to list, select, or deselect network routes.`,
|
||||
}
|
||||
|
||||
var routesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List routes",
|
||||
Example: " netbird routes list",
|
||||
Long: "List all available network routes.",
|
||||
RunE: routesList,
|
||||
}
|
||||
|
||||
var routesSelectCmd = &cobra.Command{
|
||||
Use: "select route...|all",
|
||||
Short: "Select routes",
|
||||
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
|
||||
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: routesSelect,
|
||||
}
|
||||
|
||||
var routesDeselectCmd = &cobra.Command{
|
||||
Use: "deselect route...|all",
|
||||
Short: "Deselect routes",
|
||||
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
|
||||
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: routesDeselect,
|
||||
}
|
||||
|
||||
func init() {
|
||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
|
||||
}
|
||||
|
||||
func routesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.Routes) == 0 {
|
||||
cmd.Println("No routes available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
selectedStatus := "Not Selected"
|
||||
if route.GetSelected() {
|
||||
selectedStatus = "Selected"
|
||||
}
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
} else if appendFlag {
|
||||
req.Append = true
|
||||
}
|
||||
|
||||
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Routes selected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: args,
|
||||
}
|
||||
|
||||
if len(args) == 1 && args[0] == "all" {
|
||||
req.All = true
|
||||
}
|
||||
|
||||
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
|
||||
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println("Routes deselected successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -64,6 +64,10 @@ var installCmd = &cobra.Command{
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
svcConfig.Option["OnFailure"] = "restart"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), svcConfig)
|
||||
|
||||
@@ -24,7 +24,7 @@ var (
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh",
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"You can verify the connection by running:\n\n" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -144,9 +147,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx, cmd)
|
||||
resp, err := getStatus(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -191,7 +194,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
@@ -200,7 +203,7 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -283,6 +286,11 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
if anonymizeFlag {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
anonymizeOverview(anonymizer, &overview)
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
@@ -525,8 +533,16 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
goarm := ""
|
||||
if goarch == "arm" {
|
||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"Daemon version: %s\n"+
|
||||
"OS: %s\n"+
|
||||
"Daemon version: %s\n"+
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
@@ -538,6 +554,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
@@ -593,15 +610,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
lastStatusUpdate := "-"
|
||||
if !peerState.LastStatusUpdate.IsZero() {
|
||||
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
lastWireGuardHandshake := "-"
|
||||
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
|
||||
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
@@ -652,8 +660,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
lastStatusUpdate,
|
||||
lastWireGuardHandshake,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
@@ -722,3 +730,129 @@ func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
||||
func timeAgo(t time.Time) string {
|
||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
||||
return "-"
|
||||
}
|
||||
duration := time.Since(t)
|
||||
switch {
|
||||
case duration < time.Second:
|
||||
return "Now"
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Seconds())
|
||||
if seconds == 1 {
|
||||
return "1 second ago"
|
||||
}
|
||||
return fmt.Sprintf("%d seconds ago", seconds)
|
||||
case duration < time.Hour:
|
||||
minutes := int(duration.Minutes())
|
||||
seconds := int(duration.Seconds()) % 60
|
||||
if minutes == 1 {
|
||||
if seconds == 1 {
|
||||
return "1 minute, 1 second ago"
|
||||
} else if seconds > 0 {
|
||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
||||
}
|
||||
return "1 minute ago"
|
||||
}
|
||||
if seconds > 0 {
|
||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d minutes ago", minutes)
|
||||
case duration < 24*time.Hour:
|
||||
hours := int(duration.Hours())
|
||||
minutes := int(duration.Minutes()) % 60
|
||||
if hours == 1 {
|
||||
if minutes == 1 {
|
||||
return "1 hour, 1 minute ago"
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
||||
}
|
||||
return "1 hour ago"
|
||||
}
|
||||
if minutes > 0 {
|
||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%d hours ago", hours)
|
||||
}
|
||||
|
||||
days := int(duration.Hours()) / 24
|
||||
hours := int(duration.Hours()) % 24
|
||||
if days == 1 {
|
||||
if hours == 1 {
|
||||
return "1 day, 1 hour ago"
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
||||
}
|
||||
return "1 day ago"
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
||||
}
|
||||
return fmt.Sprintf("%d days ago", days)
|
||||
}
|
||||
|
||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
||||
}
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
anonymizePeerDetail(a, &peer)
|
||||
overview.Peers.Details[i] = peer
|
||||
}
|
||||
|
||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
||||
|
||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
||||
for i, detail := range overview.Relays.Details {
|
||||
detail.URI = a.AnonymizeURI(detail.URI)
|
||||
detail.Error = a.AnonymizeString(detail.Error)
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
for j, ns := range nsGroup.Servers {
|
||||
host, port, err := net.SplitHostPort(ns)
|
||||
if err == nil {
|
||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -487,9 +489,15 @@ dnsServers:
|
||||
}
|
||||
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
// Calculate time ago based on the fixture dates
|
||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail :=
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
@@ -500,8 +508,8 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Direct: true
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Last connection update: 2001-01-01 01:01:01
|
||||
Last WireGuard handshake: 2001-01-01 01:01:02
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
@@ -516,15 +524,16 @@ func TestParsingToDetail(t *testing.T) {
|
||||
Direct: false
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Last connection update: 2002-02-02 02:02:02
|
||||
Last WireGuard handshake: 2002-02-02 02:02:03
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
Latency: 10ms
|
||||
|
||||
OS: %s/%s
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
CLI version: %s
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
@@ -539,7 +548,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
@@ -547,8 +556,8 @@ Peers count: 2/2 Connected
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString :=
|
||||
`Daemon version: 0.14.1
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
@@ -572,3 +581,31 @@ func TestParsingOfIP(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||
}
|
||||
|
||||
func TestTimeAgo(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input time.Time
|
||||
expected string
|
||||
}{
|
||||
{"Now", now, "Now"},
|
||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
||||
{"Zero time", time.Time{}, "-"},
|
||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := timeAgo(tc.input)
|
||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -29,7 +31,7 @@ import (
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// RunClientWithProbes runs the client's main logic with probes attached
|
||||
@@ -41,8 +43,9 @@ func RunClientWithProbes(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
engineChan chan<- *Engine,
|
||||
) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
@@ -64,7 +67,7 @@ func RunClientMobile(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func RunClientiOS(
|
||||
@@ -80,7 +83,7 @@ func RunClientiOS(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(
|
||||
@@ -92,8 +95,15 @@ func runClient(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
engineChan chan<- *Engine,
|
||||
) error {
|
||||
log.Infof("starting NetBird client version %s", version.NetbirdVersion())
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||
@@ -235,6 +245,9 @@ func runClient(
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
if engineChan != nil {
|
||||
engineChan <- engine
|
||||
}
|
||||
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
@@ -244,6 +257,10 @@ func runClient(
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
if engineChan != nil {
|
||||
engineChan <- nil
|
||||
}
|
||||
|
||||
err = engine.Stop()
|
||||
if err != nil {
|
||||
log.Errorf("failed stopping engine %v", err)
|
||||
|
||||
@@ -31,6 +31,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
response := d.lookupRecord(r)
|
||||
if response != nil {
|
||||
replyMessage.Answer = append(replyMessage.Answer, response)
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
err := w.WriteMsg(replyMessage)
|
||||
|
||||
@@ -93,6 +93,10 @@ type Engine struct {
|
||||
mgmClient mgm.Client
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
|
||||
beforePeerHook peer.BeforeAddPeerHookFunc
|
||||
afterPeerHook peer.AfterRemovePeerHookFunc
|
||||
|
||||
// rpManager is a Rosenpass manager
|
||||
rpManager *rosenpass.Manager
|
||||
|
||||
@@ -107,6 +111,9 @@ type Engine struct {
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
TURNs []*stun.URI
|
||||
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes map[string][]*route.Route
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
ctx context.Context
|
||||
@@ -212,6 +219,8 @@ func (e *Engine) Stop() error {
|
||||
return err
|
||||
}
|
||||
|
||||
e.clientRoutes = nil
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.CLose() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -260,9 +269,14 @@ func (e *Engine) Start() error {
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
if err := e.routeManager.Init(); err != nil {
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
e.beforePeerHook = beforePeerHook
|
||||
e.afterPeerHook = afterPeerHook
|
||||
}
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
err = e.wgInterfaceCreate()
|
||||
@@ -686,11 +700,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update routes, err: %v", err)
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutes = clientRoutes
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
@@ -785,6 +802,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
||||
FQDN: offlinePeer.GetFqdn(),
|
||||
ConnStatus: peer.StatusDisconnected,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
}
|
||||
e.statusRecorder.ReplaceOfflinePeers(replacement)
|
||||
@@ -808,10 +826,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
if _, ok := e.peerConns[peerKey]; !ok {
|
||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
e.peerConns[peerKey] = conn
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
|
||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||
if err != nil {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
@@ -1105,6 +1128,10 @@ func (e *Engine) close() {
|
||||
e.dnsServer.Stop()
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
@@ -1119,10 +1146,6 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Reset()
|
||||
if err != nil {
|
||||
@@ -1214,6 +1237,28 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the current routes from the route map
|
||||
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
|
||||
return e.clientRoutes
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
|
||||
routes := make(map[string][]*route.Route, len(e.clientRoutes))
|
||||
for id, v := range e.clientRoutes {
|
||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
||||
id = id[:i]
|
||||
}
|
||||
routes[id] = v
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetRouteManager returns the route manager
|
||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -577,10 +578,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
}{}
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
input.inputSerial = updateSerial
|
||||
input.inputRoutes = newRoutes
|
||||
return testCase.inputErr
|
||||
return nil, nil, testCase.inputErr
|
||||
},
|
||||
}
|
||||
|
||||
@@ -597,8 +598,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -742,8 +743,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
return nil
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
return nil, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
1
client/internal/engine_watcher.go
Normal file
1
client/internal/engine_watcher.go
Normal file
@@ -0,0 +1 @@
|
||||
package internal
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -100,6 +101,9 @@ type IceCredentials struct {
|
||||
Pwd string
|
||||
}
|
||||
|
||||
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
|
||||
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
|
||||
|
||||
type Conn struct {
|
||||
config ConnConfig
|
||||
mu sync.Mutex
|
||||
@@ -138,6 +142,10 @@ type Conn struct {
|
||||
|
||||
remoteEndpoint *net.UDPAddr
|
||||
remoteConn *ice.Conn
|
||||
|
||||
connID nbnet.ConnectionID
|
||||
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
@@ -221,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
|
||||
}
|
||||
|
||||
conn.agent, err = ice.NewAgent(agentConfig)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -277,6 +284,7 @@ func (conn *Conn) Open() error {
|
||||
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
|
||||
ConnStatusUpdate: time.Now(),
|
||||
ConnStatus: conn.status,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -336,6 +344,7 @@ func (conn *Conn) Open() error {
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
@@ -393,6 +402,14 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
||||
return candidate.Type() == ice.CandidateTypeRelay
|
||||
}
|
||||
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
|
||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
|
||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||
}
|
||||
|
||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
||||
conn.mu.Lock()
|
||||
@@ -419,6 +436,14 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.remoteEndpoint = endpointUdpAddr
|
||||
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connID = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
|
||||
log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
@@ -441,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
Direct: !isRelayCandidate(pair.Local),
|
||||
RosenpassEnabled: rosenpassEnabled,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
@@ -510,6 +536,15 @@ func (conn *Conn) cleanup() error {
|
||||
// todo: is it problem if we try to remove a peer what is never existed?
|
||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
|
||||
if conn.connID != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connID); err != nil {
|
||||
log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
conn.connID = ""
|
||||
|
||||
if conn.notifyDisconnected != nil {
|
||||
conn.notifyDisconnected()
|
||||
conn.notifyDisconnected = nil
|
||||
@@ -525,6 +560,7 @@ func (conn *Conn) cleanup() error {
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
// State contains the latest state of a peer
|
||||
type State struct {
|
||||
Mux *sync.RWMutex
|
||||
IP string
|
||||
PubKey string
|
||||
FQDN string
|
||||
@@ -30,7 +31,38 @@ type State struct {
|
||||
BytesRx int64
|
||||
Latency time.Duration
|
||||
RosenpassEnabled bool
|
||||
Routes map[string]struct{}
|
||||
routes map[string]struct{}
|
||||
}
|
||||
|
||||
// AddRoute add a single route to routes map
|
||||
func (s *State) AddRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
if s.routes == nil {
|
||||
s.routes = make(map[string]struct{})
|
||||
}
|
||||
s.routes[network] = struct{}{}
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// SetRoutes set state routes
|
||||
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||
s.Mux.Lock()
|
||||
s.routes = routes
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// DeleteRoute removes a route from the network amp
|
||||
func (s *State) DeleteRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
delete(s.routes, network)
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// GetRoutes return routes map
|
||||
func (s *State) GetRoutes() map[string]struct{} {
|
||||
s.Mux.RLock()
|
||||
defer s.Mux.RUnlock()
|
||||
return s.routes
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
||||
PubKey: peerPubKey,
|
||||
ConnStatus: StatusDisconnected,
|
||||
FQDN: fqdn,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.Routes != nil {
|
||||
peerState.Routes = receivedState.Routes
|
||||
if receivedState.GetRoutes() != nil {
|
||||
peerState.SetRoutes(receivedState.GetRoutes())
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState, peerState)
|
||||
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
|
||||
s, ok := gstatus.FromError(d.managementError)
|
||||
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return true
|
||||
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package peer
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
peerState := State{
|
||||
PubKey: key,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
|
||||
status.peers[key] = peerState
|
||||
|
||||
@@ -3,7 +3,9 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -18,6 +20,7 @@ type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
direct bool
|
||||
latency time.Duration
|
||||
}
|
||||
|
||||
type routesUpdate struct {
|
||||
@@ -68,6 +71,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
connected: peerStatus.ConnStatus == peer.StatusConnected,
|
||||
relayed: peerStatus.Relayed,
|
||||
direct: peerStatus.Direct,
|
||||
latency: peerStatus.Latency,
|
||||
}
|
||||
}
|
||||
return routePeerStatuses
|
||||
@@ -83,11 +87,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
// * Non-relayed: Routes without relays are preferred.
|
||||
// * Direct connections: Routes with direct peer connections are favored.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
// * Latency: Routes with lower latency are prioritized.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||
chosen := ""
|
||||
chosenScore := 0
|
||||
chosenScore := float64(0)
|
||||
currScore := float64(0)
|
||||
|
||||
currID := ""
|
||||
if c.chosenRoute != nil {
|
||||
@@ -95,7 +101,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
tempScore := 0
|
||||
tempScore := float64(0)
|
||||
peerStatus, found := routePeerStatuses[r.ID]
|
||||
if !found || !peerStatus.connected {
|
||||
continue
|
||||
@@ -103,9 +109,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
|
||||
if r.Metric < route.MaxMetric {
|
||||
metricDiff := route.MaxMetric - r.Metric
|
||||
tempScore = metricDiff * 10
|
||||
tempScore = float64(metricDiff) * 10
|
||||
}
|
||||
|
||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||
latency := time.Second
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||
}
|
||||
tempScore += 1 - latency.Seconds()
|
||||
|
||||
if !peerStatus.relayed {
|
||||
tempScore++
|
||||
}
|
||||
@@ -114,7 +129,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
tempScore++
|
||||
}
|
||||
|
||||
if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) {
|
||||
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
|
||||
chosen = r.ID
|
||||
chosenScore = tempScore
|
||||
}
|
||||
@@ -123,18 +138,30 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
|
||||
chosen = r.ID
|
||||
chosenScore = tempScore
|
||||
}
|
||||
|
||||
if r.ID == currID {
|
||||
currScore = tempScore
|
||||
}
|
||||
}
|
||||
|
||||
if chosen == "" {
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
for _, r := range c.routes {
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
|
||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||
|
||||
} else if chosen != currID {
|
||||
log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
|
||||
case chosen != currID:
|
||||
if currScore != 0 && currScore < chosenScore+0.1 {
|
||||
return currID
|
||||
} else {
|
||||
var peer string
|
||||
if route := c.routes[chosen]; route != nil {
|
||||
peer = route.Peer
|
||||
}
|
||||
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, peer, chosenScore, c.network)
|
||||
}
|
||||
}
|
||||
|
||||
return chosen
|
||||
@@ -174,7 +201,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
return fmt.Errorf("get peer state: %v", err)
|
||||
}
|
||||
|
||||
delete(state.Routes, c.network.String())
|
||||
state.DeleteRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
@@ -193,7 +220,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
|
||||
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||
}
|
||||
|
||||
@@ -234,7 +261,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
} else {
|
||||
// otherwise add the route to the system
|
||||
if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
|
||||
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
}
|
||||
@@ -246,10 +273,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
} else {
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
state.Routes[c.network.String()] = struct{}{}
|
||||
state.AddRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
@@ -325,3 +349,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getAsInterface() *net.Interface {
|
||||
intf, err := net.InterfaceByName(c.wgInterface.Name())
|
||||
if err != nil {
|
||||
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
|
||||
intf = &net.Interface{
|
||||
Name: c.wgInterface.Name(),
|
||||
}
|
||||
}
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package routemanager
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
name string
|
||||
statuses map[string]routerPeerStatus
|
||||
expectedRouteID string
|
||||
currentRoute *route.Route
|
||||
currentRoute string
|
||||
existingRoutes map[string]*route.Route
|
||||
}{
|
||||
{
|
||||
@@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
@@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
@@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
@@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer1",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "",
|
||||
},
|
||||
{
|
||||
@@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
@@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
@@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: nil,
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "multiple connected peers with different latencies",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
latency: 300 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
{
|
||||
name: "should ignore routes with latency 0",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
{
|
||||
name: "current route with similar score and similar but slightly worse latency should not change",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 12 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "current chosen route doesn't exist anymore",
|
||||
statuses: map[string]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 20 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
direct: true,
|
||||
latency: 10 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[string]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "routeDoesntExistAnymore",
|
||||
expectedRouteID: "route2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
ID: "routeDoesntExistAnymore",
|
||||
}
|
||||
if tc.currentRoute != "" {
|
||||
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||
}
|
||||
|
||||
// create new clientNetwork
|
||||
client := &clientNetwork{
|
||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
routes: tc.existingRoutes,
|
||||
chosenRoute: tc.currentRoute,
|
||||
chosenRoute: currentRoute,
|
||||
}
|
||||
|
||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||
|
||||
@@ -3,7 +3,9 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
@@ -12,8 +14,10 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -24,8 +28,10 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() error
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
||||
TriggerSelection(map[string][]*route.Route)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
@@ -38,6 +44,7 @@ type DefaultManager struct {
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
@@ -51,6 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
@@ -65,16 +73,25 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() error {
|
||||
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
if err := setupRouting(); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
mgmtAddress := m.statusRecorder.GetManagementState().URL
|
||||
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||
|
||||
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
log.Info("Routing setup complete")
|
||||
return nil
|
||||
return beforePeerHook, afterPeerHook, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
@@ -92,37 +109,42 @@ func (m *DefaultManager) Stop() {
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
}
|
||||
|
||||
m.ctx = nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
return nil, nil, m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
m.notifier.onNewRoutes(newClientRoutesIDMap)
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.onNewRoutes(filteredClientRoutes)
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update routes: %w", err)
|
||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,16 +158,51 @@ func (m *DefaultManager) InitialRouteRange() []string {
|
||||
return m.notifier.initialRouteRanges()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
// GetRouteSelector returns the route selector
|
||||
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return m.routeSelector
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the client routes
|
||||
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
|
||||
return m.clientNetworks
|
||||
}
|
||||
|
||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
networks = m.routeSelector.FilterSelected(networks)
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
if _, found := m.clientNetworks[id]; found {
|
||||
// don't touch existing client network watchers
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
}
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
@@ -162,7 +219,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
||||
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
@@ -194,7 +251,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
||||
}
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifiesRoutes(initialRoutes)
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0)
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
@@ -203,16 +260,38 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
||||
}
|
||||
|
||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||
if runtime.GOOS == "linux" {
|
||||
return true
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "windows", "darwin", "ios":
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
// we skip this prefix management
|
||||
if prefix.Bits() < minRangeBits {
|
||||
if prefix.Bits() <= minRangeBits {
|
||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||
version.NetbirdVersion(), prefix)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
|
||||
func resolveURLsToIPs(urls []string) []net.IP {
|
||||
var ips []net.IP
|
||||
for _, rawurl := range urls {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse url %s: %v", rawurl, err)
|
||||
continue
|
||||
}
|
||||
ipAddrs, err := net.LookupIP(u.Hostname())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
|
||||
continue
|
||||
}
|
||||
ips = append(ips, ipAddrs...)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
@@ -28,14 +28,14 @@ const remotePeerKey2 = "remote1"
|
||||
|
||||
func TestManagerUpdateRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
clientNetworkWatchersExpectedLinux int
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
clientNetworkWatchersExpectedAllowed int
|
||||
}{
|
||||
{
|
||||
name: "Should create 2 client networks",
|
||||
@@ -201,9 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
clientNetworkWatchersExpectedLinux: 1,
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
clientNetworkWatchersExpectedAllowed: 1,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route",
|
||||
@@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
err = routeManager.Init()
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop()
|
||||
|
||||
@@ -426,16 +428,16 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 {
|
||||
expectedWatchers = testCase.clientNetworkWatchersExpectedLinux
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
||||
}
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
||||
|
||||
@@ -6,18 +6,22 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||
StopFunc func()
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
||||
TriggerSelectionFunc func(map[string][]*route.Route)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() error {
|
||||
return nil
|
||||
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||
@@ -26,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string {
|
||||
}
|
||||
|
||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
if m.UpdateRoutesFunc != nil {
|
||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||
}
|
||||
return fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
}
|
||||
|
||||
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
|
||||
if m.TriggerSelectionFunc != nil {
|
||||
m.TriggerSelectionFunc(networks)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
|
||||
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
if m.GetRouteSelectorFunc != nil {
|
||||
return m.GetRouteSelectorFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start mock implementation of Start from Manager interface
|
||||
|
||||
127
client/internal/routemanager/routemanager.go
Normal file
127
client/internal/routemanager/routemanager.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ref struct {
|
||||
count int
|
||||
nexthop netip.Addr
|
||||
intf *net.Interface
|
||||
}
|
||||
|
||||
type RouteManager struct {
|
||||
// refCountMap keeps track of the reference ref for prefixes
|
||||
refCountMap map[netip.Prefix]ref
|
||||
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
||||
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
||||
addRoute AddRouteFunc
|
||||
removeRoute RemoveRouteFunc
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
|
||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
|
||||
|
||||
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||
// TODO: read initial routing table into refCountMap
|
||||
return &RouteManager{
|
||||
refCountMap: map[netip.Prefix]ref{},
|
||||
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
||||
addRoute: addRoute,
|
||||
removeRoute: removeRoute,
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
|
||||
// Add route to the system, only if it's a new prefix
|
||||
if ref.count == 0 {
|
||||
log.Debugf("Adding route for prefix %s", prefix)
|
||||
nexthop, intf, err := rm.addRoute(prefix)
|
||||
if errors.Is(err, ErrRouteNotFound) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, ErrRouteNotAllowed) {
|
||||
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.nexthop = nexthop
|
||||
ref.intf = intf
|
||||
}
|
||||
|
||||
ref.count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
prefixes, ok := rm.prefixMap[connID]
|
||||
if !ok {
|
||||
log.Debugf("No prefixes found for connection ID %s", connID)
|
||||
return nil
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
if ref.count == 1 {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
// TODO: don't fail if the route is not found
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
}
|
||||
delete(rm.prefixMap, connID)
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Flush removes all references and routes from the system
|
||||
func (rm *RouteManager) Flush() error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]ref{}
|
||||
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
@@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||
}
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
state.Routes = nil
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
state.Routes = nil
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||
parsed, err := netip.ParsePrefix(source)
|
||||
if err != nil {
|
||||
|
||||
414
client/internal/routemanager/systemops.go
Normal file
414
client/internal/routemanager/systemops.go
Normal file
@@ -0,0 +1,414 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRouteNotFound = errors.New("route not found")
|
||||
var ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
defaultGateway, _, err := getNextHop(addr)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(defaultGateway) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||
if defaultGateway.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayHop, intf, err := getNextHop(defaultGateway)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||
}
|
||||
|
||||
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get route for %s: %v", ip, err)
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if preferredSrc == nil {
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
addr, err := ipToAddr(preferredSrc, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
}
|
||||
return addr.Unmap(), intf, nil
|
||||
}
|
||||
|
||||
addr, err := ipToAddr(gateway, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
||||
}
|
||||
|
||||
return addr, intf, nil
|
||||
}
|
||||
|
||||
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||
}
|
||||
|
||||
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
|
||||
if runtime.GOOS == "windows" {
|
||||
addr = addr.WithZone(strconv.Itoa(intf.Index))
|
||||
} else {
|
||||
addr = addr.WithZone(intf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return netip.Addr{}, nil, ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, intf, err := getNextHop(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||
exitNextHop := nexthop
|
||||
exitIntf := intf
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||
exitNextHop = initialNextHop
|
||||
exitIntf = initialIntf
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, exitIntf, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == defaultv6 {
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := addRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
} else if prefix == defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return &prefix, nil
|
||||
}
|
||||
|
||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
*routeManager = NewRouteManager(
|
||||
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||
if addr.Is6() {
|
||||
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||
}
|
||||
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||
},
|
||||
removeFromRouteTable,
|
||||
)
|
||||
|
||||
return setupHooks(*routeManager, initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||
if routeManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := routeManager.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := getPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
@@ -1,13 +1,33 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
@@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(m.Addrs) < 3 {
|
||||
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
|
||||
continue
|
||||
}
|
||||
|
||||
addr, ok := toNetIPAddr(m.Addrs[0])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
mask, ok := toNetIPMASK(m.Addrs[2])
|
||||
if !ok {
|
||||
continue
|
||||
cidr := 32
|
||||
if mask := m.Addrs[2]; mask != nil {
|
||||
cidr, ok = toCIDR(mask)
|
||||
if !ok {
|
||||
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
|
||||
continue
|
||||
}
|
||||
}
|
||||
cidr, _ := mask.Size()
|
||||
|
||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
||||
if routePrefix.IsValid() {
|
||||
@@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
||||
switch t := a.(type) {
|
||||
case *route.Inet4Addr:
|
||||
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
||||
addr := netip.MustParseAddr(ip.String())
|
||||
return addr, true
|
||||
return netip.AddrFrom4(t.IP), true
|
||||
default:
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
|
||||
func toCIDR(a route.Addr) (int, bool) {
|
||||
switch t := a.(type) {
|
||||
case *route.Inet4Addr:
|
||||
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
||||
return mask, true
|
||||
cidr, _ := mask.Size()
|
||||
return cidr, true
|
||||
default:
|
||||
return nil, false
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import "net/netip"
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
|
||||
}
|
||||
89
client/internal/routemanager/systemops_darwin.go
Normal file
89
client/internal/routemanager/systemops_darwin.go
Normal file
@@ -0,0 +1,89 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("add", prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("delete", prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
inet := "-inet"
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
// Special case for IPv6 split default route, pointing to the wg interface fails
|
||||
// TODO: Remove once we have IPv6 support on the interface
|
||||
if prefix.Bits() == 1 {
|
||||
intf = &net.Interface{Name: "lo0"}
|
||||
}
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else if intf != nil {
|
||||
args = append(args, "-interface", intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
||||
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func retryRouteCmd(args []string) error {
|
||||
operation := func() error {
|
||||
out, err := exec.Command("route", args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
// https://github.com/golang/go/issues/45736
|
||||
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
|
||||
return err
|
||||
} else if err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||
|
||||
err := backoff.Retry(operation, expBackOff)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route cmd retry failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
138
client/internal/routemanager/systemops_darwin_test.go
Normal file
138
client/internal/routemanager/systemops_darwin_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
//go:build !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var expectedVPNint = "utun100"
|
||||
var expectedExternalInt = "lo0"
|
||||
var expectedInternalInt = "lo0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn",
|
||||
destination: "10.10.0.2:53",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
baseIP = netip.MustParseAddr("192.0.2.0")
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
baseIP = baseIP.Next()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return "lo0"
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
if dstCIDR == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, err = fetchOriginalGateway()
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if originalNexthop != nil {
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
||||
assert.NoError(t, err, "Failed to restore original route")
|
||||
}
|
||||
})
|
||||
|
||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
||||
require.NoError(t, err, "Failed to add route")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
||||
assert.NoError(t, err, "Failed to remove route")
|
||||
})
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (net.IP, error) {
|
||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("gateway not found")
|
||||
}
|
||||
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||
}
|
||||
@@ -1,13 +1,33 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,12 +9,16 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -28,27 +32,53 @@ const (
|
||||
rtTablesPath = "/etc/iproute2/rt_tables"
|
||||
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
|
||||
var routeManager = &RouteManager{}
|
||||
|
||||
// originalSysctl stores the original sysctl values before they are modified
|
||||
var originalSysctl map[string]int
|
||||
|
||||
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
|
||||
var sysctlFailed bool
|
||||
|
||||
type ruleParams struct {
|
||||
priority int
|
||||
fwmark int
|
||||
tableID int
|
||||
family int
|
||||
priority int
|
||||
invert bool
|
||||
suppressPrefix int
|
||||
description string
|
||||
}
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
func setIsLegacy(b bool) {
|
||||
if b {
|
||||
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
} else {
|
||||
os.Unsetenv("NB_USE_LEGACY_ROUTING")
|
||||
}
|
||||
}
|
||||
|
||||
func getSetupRules() []ruleParams {
|
||||
return []ruleParams{
|
||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
|
||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
|
||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
|
||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
|
||||
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +92,23 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
//
|
||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
||||
func setupRouting() (err error) {
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
if isLegacy() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := setupSysctl(wgIface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||
@@ -80,17 +120,26 @@ func setupRouting() (err error) {
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := addRule(rule); err != nil {
|
||||
return fmt.Errorf("%s: %w", rule.description, err)
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||
setIsLegacy(true)
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func cleanupRouting() error {
|
||||
if isLegacy() {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
|
||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||
@@ -102,61 +151,122 @@ func cleanupRouting() error {
|
||||
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := removeAllRules(rule); err != nil {
|
||||
if err := removeRule(rule); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||
}
|
||||
originalSysctl = nil
|
||||
sysctlFailed = false
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error {
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||
}
|
||||
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
||||
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add blackhole: %w", err)
|
||||
}
|
||||
}
|
||||
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error {
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
||||
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove unreachable route: %w", err)
|
||||
}
|
||||
}
|
||||
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4)
|
||||
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get v4 routes: %w", err)
|
||||
}
|
||||
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get v6 routes: %w", err)
|
||||
|
||||
}
|
||||
return append(v4Routes, v6Routes...), nil
|
||||
}
|
||||
|
||||
// getRoutes fetches routes from a specific routing table identified by tableID.
|
||||
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
var prefixList []netip.Prefix
|
||||
|
||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Dst != nil {
|
||||
addr, ok := netip.AddrFromSlice(route.Dst.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
|
||||
}
|
||||
|
||||
ones, _ := route.Dst.Mask.Size()
|
||||
|
||||
prefix := netip.PrefixFrom(addr, ones)
|
||||
if prefix.IsValid() {
|
||||
prefixList = append(prefixList, prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
// addRoute adds a route to a specific routing table identified by tableID.
|
||||
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
Family: family,
|
||||
Family: getAddressFamily(prefix),
|
||||
}
|
||||
|
||||
if prefix != nil {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
}
|
||||
route.Dst = ipNet
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
}
|
||||
route.Dst = ipNet
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
@@ -172,7 +282,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err
|
||||
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
||||
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
||||
// tableID specifies the routing table to which the unreachable route will be added.
|
||||
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@@ -181,7 +291,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
route := &netlink.Route{
|
||||
Type: syscall.RTN_UNREACHABLE,
|
||||
Table: tableID,
|
||||
Family: ipFamily,
|
||||
Family: getAddressFamily(prefix),
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
@@ -192,7 +302,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@@ -201,11 +311,14 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
route := &netlink.Route{
|
||||
Type: syscall.RTN_UNREACHABLE,
|
||||
Table: tableID,
|
||||
Family: ipFamily,
|
||||
Family: getAddressFamily(prefix),
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RouteDel(route); err != nil &&
|
||||
!errors.Is(err, syscall.ESRCH) &&
|
||||
!errors.Is(err, syscall.ENOENT) &&
|
||||
!errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
||||
}
|
||||
|
||||
@@ -214,7 +327,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
}
|
||||
|
||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@@ -223,7 +336,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int)
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
Family: family,
|
||||
Family: getAddressFamily(prefix),
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
@@ -263,51 +376,9 @@ func flushRoutes(tableID, family int) error {
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// getRoutes fetches routes from a specific routing table identified by tableID.
|
||||
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
var prefixList []netip.Prefix
|
||||
|
||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Dst != nil {
|
||||
addr, ok := netip.AddrFromSlice(route.Dst.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
|
||||
}
|
||||
|
||||
ones, _ := route.Dst.Mask.Size()
|
||||
|
||||
prefix := netip.PrefixFrom(addr, ones)
|
||||
if prefix.IsValid() {
|
||||
prefixList = append(prefixList, prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
|
||||
}
|
||||
|
||||
// check if it is already enabled
|
||||
// see more: https://github.com/netbirdio/netbird/issues/872
|
||||
if len(bytes) > 0 && bytes[0] == 49 {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
|
||||
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
|
||||
}
|
||||
return nil
|
||||
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||
return err
|
||||
}
|
||||
|
||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||
@@ -385,7 +456,7 @@ func addRule(params ruleParams) error {
|
||||
rule.Invert = params.invert
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("add routing rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -402,43 +473,118 @@ func removeRule(params ruleParams) error {
|
||||
rule.Priority = params.priority
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeAllRules(params ruleParams) error {
|
||||
for {
|
||||
if err := removeRule(params); err != nil {
|
||||
if errors.Is(err, syscall.ENOENT) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
|
||||
if addr != nil {
|
||||
ip := net.ParseIP(*addr)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parsing address %s failed", *addr)
|
||||
}
|
||||
|
||||
route.Gw = ip
|
||||
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
|
||||
if intf != nil {
|
||||
route.LinkIndex = intf.Index
|
||||
}
|
||||
|
||||
if intf != nil {
|
||||
link, err := netlink.LinkByName(*intf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set interface %s: %w", *intf, err)
|
||||
if addr.IsValid() {
|
||||
route.Gw = addr.AsSlice()
|
||||
|
||||
// if zone is set, it means the gateway is a link-local address, so we set the link index
|
||||
if addr.Zone() != "" && intf == nil {
|
||||
link, err := netlink.LinkByName(addr.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAddressFamily(prefix netip.Prefix) int {
|
||||
if prefix.Addr().Is4() {
|
||||
return netlink.FAMILY_V4
|
||||
}
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := setSysctl(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
func cleanupSysctl(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := setSysctl(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
@@ -6,34 +6,38 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gopacket/gopacket"
|
||||
"github.com/gopacket/gopacket/layers"
|
||||
"github.com/gopacket/gopacket/pcap"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedLoopbackInt = "lo"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
func init() {
|
||||
testCases = append(testCases, []testCase{
|
||||
{
|
||||
name: "To more specific route without custom dialer via physical interface",
|
||||
destination: "10.10.0.2:53",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To more specific route (local) without custom dialer via physical interface",
|
||||
destination: "127.0.10.1:53",
|
||||
expectedInterface: expectedLoopbackInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
@@ -92,157 +96,7 @@ func TestEntryExists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingWithTables(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
destination string
|
||||
captureInterface string
|
||||
dialer *net.Dialer
|
||||
packetExpectation PacketExpectation
|
||||
}{
|
||||
{
|
||||
name: "To external host without fwmark via vpn",
|
||||
destination: "192.0.2.1:53",
|
||||
captureInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with fwmark via physical interface",
|
||||
destination: "192.0.2.1:53",
|
||||
captureInterface: "dummyext0",
|
||||
dialer: nbnet.NewDialer(),
|
||||
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with fwmark via physical interface",
|
||||
destination: "10.0.0.1:53",
|
||||
captureInterface: "dummyint0",
|
||||
dialer: nbnet.NewDialer(),
|
||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence
|
||||
destination: "10.0.0.1:53",
|
||||
captureInterface: "dummyint0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with fwmark via physical interface",
|
||||
destination: "172.16.0.1:53",
|
||||
captureInterface: "dummyext0",
|
||||
dialer: nbnet.NewDialer(),
|
||||
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without fwmark via vpn",
|
||||
destination: "172.16.0.1:53",
|
||||
captureInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route without fwmark via vpn interface",
|
||||
destination: "10.10.0.1:53",
|
||||
captureInterface: "dummyint0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route (local) without fwmark via physical interface",
|
||||
destination: "127.0.10.1:53",
|
||||
captureInterface: "lo",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
wgIface, _, _ := setupTestEnv(t)
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// unique route in vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
filter := createBPFFilter(tc.destination)
|
||||
handle := startPacketCapture(t, tc.captureInterface, filter)
|
||||
|
||||
sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer)
|
||||
|
||||
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||
packet, err := packetSource.NextPacket()
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyPacket(t, packet, tc.packetExpectation)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
||||
t.Helper()
|
||||
|
||||
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
||||
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
||||
|
||||
ip, ok := ipLayer.(*layers.IPv4)
|
||||
require.True(t, ok, "Failed to cast to IPv4 layer")
|
||||
|
||||
// Convert both source and destination IP addresses to 16-byte representation
|
||||
expectedSrcIP := exp.SrcIP.To16()
|
||||
actualSrcIP := ip.SrcIP.To16()
|
||||
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
||||
|
||||
expectedDstIP := exp.DstIP.To16()
|
||||
actualDstIP := ip.DstIP.To16()
|
||||
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
||||
|
||||
if exp.UDP {
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
||||
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
require.True(t, ok, "Failed to cast to UDP layer")
|
||||
|
||||
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
||||
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
||||
}
|
||||
|
||||
if exp.TCP {
|
||||
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
||||
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
||||
|
||||
tcp, ok := tcpLayer.(*layers.TCP)
|
||||
require.True(t, ok, "Failed to cast to TCP layer")
|
||||
|
||||
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
||||
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
||||
@@ -264,35 +118,52 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return dummy
|
||||
t.Cleanup(func() {
|
||||
err := netlink.LinkDel(dummy)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
return dummy.Name
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
||||
t.Helper()
|
||||
|
||||
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Handle existing routes with metric 0
|
||||
var originalNexthop net.IP
|
||||
var originalLinkIndex int
|
||||
if dstIPNet.String() == "0.0.0.0/0" {
|
||||
gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4)
|
||||
if err != nil {
|
||||
var err error
|
||||
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
// Handle existing routes with metric 0
|
||||
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
||||
if err == nil {
|
||||
t.Cleanup(func() {
|
||||
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0})
|
||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
t.Fatalf("Failed to add route: %v", err)
|
||||
}
|
||||
})
|
||||
} else if !errors.Is(err, syscall.ESRCH) {
|
||||
t.Logf("Failed to delete route: %v", err)
|
||||
if originalNexthop != nil {
|
||||
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
||||
switch {
|
||||
case err != nil && !errors.Is(err, syscall.ESRCH):
|
||||
t.Logf("Failed to delete route: %v", err)
|
||||
case err == nil:
|
||||
t.Cleanup(func() {
|
||||
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
|
||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
t.Fatalf("Failed to add route: %v", err)
|
||||
}
|
||||
})
|
||||
default:
|
||||
t.Logf("Failed to delete route: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
link, err := netlink.LinkByName(intf)
|
||||
require.NoError(t, err)
|
||||
linkIndex := link.Attrs().Index
|
||||
|
||||
route := &netlink.Route{
|
||||
Dst: dstIPNet,
|
||||
Gw: gw,
|
||||
@@ -307,9 +178,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
|
||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
t.Fatalf("Failed to add route: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// fetchOriginalGateway returns the original gateway IP address and the interface index.
|
||||
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||
routes, err := netlink.RouteList(nil, family)
|
||||
if err != nil {
|
||||
@@ -317,153 +188,20 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil {
|
||||
if route.Dst == nil && route.Priority == 0 {
|
||||
return route.Gw, route.LinkIndex, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("default route not found")
|
||||
return nil, 0, ErrRouteNotFound
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index)
|
||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := netlink.LinkDel(defaultDummy)
|
||||
assert.NoError(t, err)
|
||||
err = netlink.LinkDel(otherDummy)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
return defaultDummy.Name, otherDummy.Name
|
||||
}
|
||||
|
||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||
t.Helper()
|
||||
|
||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
t.Cleanup(func() {
|
||||
wgInterface.Close()
|
||||
})
|
||||
|
||||
return wgInterface
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t)
|
||||
|
||||
wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, wgIface.Close())
|
||||
})
|
||||
|
||||
err := setupRouting()
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
return wgIface, defaultDummy, otherDummy
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
inactive, err := pcap.NewInactiveHandle(intf)
|
||||
require.NoError(t, err, "Failed to create inactive pcap handle")
|
||||
defer inactive.CleanUp()
|
||||
|
||||
err = inactive.SetSnapLen(1600)
|
||||
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
||||
|
||||
err = inactive.SetTimeout(time.Second * 10)
|
||||
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
||||
|
||||
err = inactive.SetImmediateMode(true)
|
||||
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
||||
|
||||
handle, err := inactive.Activate()
|
||||
require.NoError(t, err, "Failed to activate pcap handle")
|
||||
t.Cleanup(handle.Close)
|
||||
|
||||
err = handle.SetBPFFilter(filter)
|
||||
require.NoError(t, err, "Failed to set BPF filter")
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
|
||||
t.Helper()
|
||||
|
||||
if dialer == nil {
|
||||
dialer = &net.Dialer{}
|
||||
}
|
||||
|
||||
if sourcePort != 0 {
|
||||
localUDPAddr := &net.UDPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: sourcePort,
|
||||
}
|
||||
dialer.LocalAddr = localUDPAddr
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.Id = dns.Id()
|
||||
msg.RecursionDesired = true
|
||||
msg.Question = []dns.Question{
|
||||
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("udp", destination)
|
||||
require.NoError(t, err, "Failed to dial UDP")
|
||||
defer conn.Close()
|
||||
|
||||
data, err := msg.Pack()
|
||||
require.NoError(t, err, "Failed to pack DNS message")
|
||||
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "required key not available") {
|
||||
t.Logf("Ignoring WireGuard key error: %v", err)
|
||||
return
|
||||
}
|
||||
t.Fatalf("Failed to send DNS query: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createBPFFilter(destination string) string {
|
||||
host, port, err := net.SplitHostPort(destination)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
||||
}
|
||||
return "udp"
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||
}
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
//nolint:unused
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var errRouteNotFound = fmt.Errorf("route not found")
|
||||
|
||||
func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
defaultGateway, err := getExistingRIBRouteGateway(defaultv4)
|
||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddr(defaultGateway.String())
|
||||
|
||||
if !prefix.Contains(addr) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(addr, 32)
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
|
||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "")
|
||||
}
|
||||
|
||||
func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := genericAddRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return genericAddToRouteTable(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericRemoveFromRouteTable(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error {
|
||||
cmd := exec.Command("route", "add", prefix.String(), addr)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if runtime.GOOS == "darwin" {
|
||||
args = append(args, addr)
|
||||
}
|
||||
cmd := exec.Command("route", args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
|
||||
if err != nil {
|
||||
log.Errorf("Getting routes returned an error: %v", err)
|
||||
return nil, errRouteNotFound
|
||||
}
|
||||
|
||||
if gateway == nil {
|
||||
return preferredSrc, nil
|
||||
}
|
||||
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -1,282 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String())
|
||||
require.NoError(t, err, "getOutgoingInterfaceLinux should not return error")
|
||||
if invert {
|
||||
require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface")
|
||||
} else {
|
||||
require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
prefixGateway, err := getExistingRIBRouteGateway(prefix)
|
||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
|
||||
func getOutgoingInterfaceLinux(destination string) (string, error) {
|
||||
cmd := exec.Command("ip", "route", "get", destination)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("executing ip route get: %w", err)
|
||||
}
|
||||
|
||||
return parseOutgoingInterface(string(output)), nil
|
||||
}
|
||||
|
||||
func parseOutgoingInterface(routeGetOutput string) string {
|
||||
fields := strings.Fields(routeGetOutput)
|
||||
for i, field := range fields {
|
||||
if field == "dev" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestAddRemoveRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
shouldRouteToWireguard bool
|
||||
shouldBeRemoved bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove Route 100.66.120.0/24",
|
||||
prefix: netip.MustParsePrefix("100.66.120.0/24"),
|
||||
shouldRouteToWireguard: true,
|
||||
shouldBeRemoved: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Or Remove Route 127.0.0.1/32",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
require.NoError(t, setupRouting())
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||
} else {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
||||
}
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
|
||||
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
|
||||
|
||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||
|
||||
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExistingRIBRouteGateway(t *testing.T) {
|
||||
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if gateway == nil {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var testingIP string
|
||||
var testingPrefix netip.Prefix
|
||||
for _, address := range addresses {
|
||||
if address.Network() != "ip+net" {
|
||||
continue
|
||||
}
|
||||
prefix := netip.MustParsePrefix(address.String())
|
||||
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
|
||||
testingIP = prefix.Addr().String()
|
||||
testingPrefix = prefix.Masked()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
localIP, err := getExistingRIBRouteGateway(testingPrefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if localIP == nil {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.String() == gateway.String() {
|
||||
t.Fatal("local ip should not match with gateway IP")
|
||||
}
|
||||
if localIP.String() != testingIP {
|
||||
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
||||
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
t.Log("defaultGateway: ", defaultGateway)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
preExistingPrefix netip.Prefix
|
||||
shouldAddRoute bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove random Route",
|
||||
prefix: netip.MustParsePrefix("99.99.99.99/32"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if bigger network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if smaller network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if same network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
require.NoError(t, setupRouting())
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
MockAddr := wgInterface.Address().IP.String()
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name())
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name())
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
// test if route exists after adding
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name())
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
// route should either not have been added or should have been removed
|
||||
// In case of already existing route, it should not have been added (but still exist)
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
t.Log("Buffer string: ", buf.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
// Linux uses a separate routing table, so the route can exist in both tables.
|
||||
// The main routing table takes precedence over the wireguard routing table.
|
||||
if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" {
|
||||
require.False(t, ok, "route should not exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,24 @@
|
||||
//go:build !linux || android
|
||||
//go:build !linux && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func setupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var subRangeAddressPrefixes []netip.Prefix
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range subRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if !isSubRangePrefix {
|
||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if isSubRangePrefix {
|
||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
require.NoError(t, setupRouting())
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if p.Addr().Is4() {
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
exists, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("address %s should exist in route table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
420
client/internal/routemanager/systemops_test.go
Normal file
420
client/internal/routemanager/systemops_test.go
Normal file
@@ -0,0 +1,420 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddRemoveRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
shouldRouteToWireguard bool
|
||||
shouldBeRemoved bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove Route 100.66.120.0/24",
|
||||
prefix: netip.MustParsePrefix("100.66.120.0/24"),
|
||||
shouldRouteToWireguard: true,
|
||||
shouldBeRemoved: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Or Remove Route 127.0.0.1/32",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
_, _, err = setupRouting(nil, wgInterface)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||
} else {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
||||
}
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err, "getNextHop should not return err")
|
||||
|
||||
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNextHop(t *testing.T) {
|
||||
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if !gateway.IsValid() {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var testingIP string
|
||||
var testingPrefix netip.Prefix
|
||||
for _, address := range addresses {
|
||||
if address.Network() != "ip+net" {
|
||||
continue
|
||||
}
|
||||
prefix := netip.MustParsePrefix(address.String())
|
||||
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
|
||||
testingIP = prefix.Addr().String()
|
||||
testingPrefix = prefix.Masked()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
localIP, _, err := getNextHop(testingPrefix.Addr())
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if !localIP.IsValid() {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.String() == gateway.String() {
|
||||
t.Fatal("local ip should not match with gateway IP")
|
||||
}
|
||||
if localIP.String() != testingIP {
|
||||
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultGateway: ", defaultGateway)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
preExistingPrefix netip.Prefix
|
||||
shouldAddRoute bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove random Route",
|
||||
prefix: netip.MustParsePrefix("99.99.99.99/32"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if bigger network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if smaller network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if same network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := addVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
// test if route exists after adding
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
// route should either not have been added or should have been removed
|
||||
// In case of already existing route, it should not have been added (but still exist)
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
t.Log("Buffer string: ", buf.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
if !strings.Contains(buf.String(), "because it already exists") {
|
||||
require.False(t, ok, "route should not exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var subRangeAddressPrefixes []netip.Prefix
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range subRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if !isSubRangePrefix {
|
||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if isSubRangePrefix {
|
||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if p.Addr().Is6() {
|
||||
continue
|
||||
}
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
exists, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("address %s should exist in route table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||
t.Helper()
|
||||
|
||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
t.Cleanup(func() {
|
||||
wgInterface.Close()
|
||||
})
|
||||
|
||||
return wgInterface
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
setupDummyInterfacesAndRoutes(t)
|
||||
|
||||
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, wgIface.Close())
|
||||
})
|
||||
|
||||
_, _, err := setupRouting(nil, wgIface)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgIface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// unique route in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
}
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
||||
return
|
||||
}
|
||||
|
||||
prefixGateway, _, err := getNextHop(prefix.Addr())
|
||||
require.NoError(t, err, "getNextHop should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
234
client/internal/routemanager/systemops_unix_test.go
Normal file
234
client/internal/routemanager/systemops_unix_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gopacket/gopacket"
|
||||
"github.com/gopacket/gopacket/layers"
|
||||
"github.com/gopacket/gopacket/pcap"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
destination string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
expectedPacket PacketExpectation
|
||||
}
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
destination: "192.0.2.1:53",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
destination: "192.0.2.1:53",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
destination: "10.0.0.2:53",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
destination: "10.0.0.2:53",
|
||||
expectedInterface: expectedInternalInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
destination: "172.16.0.2:53",
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
destination: "172.16.0.2:53",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupTestEnv(t)
|
||||
|
||||
filter := createBPFFilter(tc.destination)
|
||||
handle := startPacketCapture(t, tc.expectedInterface, filter)
|
||||
|
||||
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
|
||||
|
||||
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||
packet, err := packetSource.NextPacket()
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyPacket(t, packet, tc.expectedPacket)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
inactive, err := pcap.NewInactiveHandle(intf)
|
||||
require.NoError(t, err, "Failed to create inactive pcap handle")
|
||||
defer inactive.CleanUp()
|
||||
|
||||
err = inactive.SetSnapLen(1600)
|
||||
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
||||
|
||||
err = inactive.SetTimeout(time.Second * 10)
|
||||
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
||||
|
||||
err = inactive.SetImmediateMode(true)
|
||||
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
||||
|
||||
handle, err := inactive.Activate()
|
||||
require.NoError(t, err, "Failed to activate pcap handle")
|
||||
t.Cleanup(handle.Close)
|
||||
|
||||
err = handle.SetBPFFilter(filter)
|
||||
require.NoError(t, err, "Failed to set BPF filter")
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
|
||||
t.Helper()
|
||||
|
||||
if dialer == nil {
|
||||
dialer = &net.Dialer{}
|
||||
}
|
||||
|
||||
if sourcePort != 0 {
|
||||
localUDPAddr := &net.UDPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: sourcePort,
|
||||
}
|
||||
switch dialer := dialer.(type) {
|
||||
case *nbnet.Dialer:
|
||||
dialer.LocalAddr = localUDPAddr
|
||||
case *net.Dialer:
|
||||
dialer.LocalAddr = localUDPAddr
|
||||
default:
|
||||
t.Fatal("Unsupported dialer type")
|
||||
}
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.Id = dns.Id()
|
||||
msg.RecursionDesired = true
|
||||
msg.Question = []dns.Question{
|
||||
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("udp", destination)
|
||||
require.NoError(t, err, "Failed to dial UDP")
|
||||
defer conn.Close()
|
||||
|
||||
data, err := msg.Pack()
|
||||
require.NoError(t, err, "Failed to pack DNS message")
|
||||
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "required key not available") {
|
||||
t.Logf("Ignoring WireGuard key error: %v", err)
|
||||
return
|
||||
}
|
||||
t.Fatalf("Failed to send DNS query: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createBPFFilter(destination string) string {
|
||||
host, port, err := net.SplitHostPort(destination)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
||||
}
|
||||
return "udp"
|
||||
}
|
||||
|
||||
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
||||
t.Helper()
|
||||
|
||||
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
||||
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
||||
|
||||
ip, ok := ipLayer.(*layers.IPv4)
|
||||
require.True(t, ok, "Failed to cast to IPv4 layer")
|
||||
|
||||
// Convert both source and destination IP addresses to 16-byte representation
|
||||
expectedSrcIP := exp.SrcIP.To16()
|
||||
actualSrcIP := ip.SrcIP.To16()
|
||||
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
||||
|
||||
expectedDstIP := exp.DstIP.To16()
|
||||
actualDstIP := ip.DstIP.To16()
|
||||
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
||||
|
||||
if exp.UDP {
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
||||
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
require.True(t, ok, "Failed to cast to UDP layer")
|
||||
|
||||
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
||||
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
||||
}
|
||||
|
||||
if exp.TCP {
|
||||
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
||||
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
||||
|
||||
tcp, ok := tcpLayer.(*layers.TCP)
|
||||
require.True(t, ok, "Failed to cast to TCP layer")
|
||||
|
||||
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
||||
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,18 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type Win32_IP4RouteTable struct {
|
||||
@@ -16,16 +25,38 @@ type Win32_IP4RouteTable struct {
|
||||
Mask string
|
||||
}
|
||||
|
||||
var prefixList []netip.Prefix
|
||||
var lastUpdate time.Time
|
||||
var mux = sync.Mutex{}
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
var routes []Win32_IP4RouteTable
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
||||
|
||||
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
|
||||
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
var routes []Win32_IP4RouteTable
|
||||
err := wmi.Query(query, &routes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
var prefixList []netip.Prefix
|
||||
prefixList = nil
|
||||
for _, route := range routes {
|
||||
addr, err := netip.ParseAddr(route.Destination)
|
||||
if err != nil {
|
||||
@@ -45,13 +76,66 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
prefixList = append(prefixList, routePrefix)
|
||||
}
|
||||
}
|
||||
|
||||
lastUpdate = time.Now()
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
args := []string{"add", prefix.String()}
|
||||
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else {
|
||||
addr := "0.0.0.0"
|
||||
if prefix.Addr().Is6() {
|
||||
addr = "::"
|
||||
}
|
||||
args = append(args, addr)
|
||||
}
|
||||
|
||||
if intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(intf.Index))
|
||||
}
|
||||
|
||||
out, err := exec.Command("route", args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route add: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
if nexthop.Zone() != "" && intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
intf = &net.Interface{Index: zone}
|
||||
nexthop.WithZone("")
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IsValid() {
|
||||
nexthop.WithZone("")
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
}
|
||||
|
||||
out, err := exec.Command("route", args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCacheDisabled() bool {
|
||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||
}
|
||||
|
||||
289
client/internal/routemanager/systemops_windows_test.go
Normal file
289
client/internal/routemanager/systemops_windows_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var expectedExtInt = "Ethernet1"
|
||||
|
||||
type RouteInfo struct {
|
||||
NextHop string `json:"nexthop"`
|
||||
InterfaceAlias string `json:"interfacealias"`
|
||||
RouteMetric int `json:"routemetric"`
|
||||
}
|
||||
|
||||
type FindNetRouteOutput struct {
|
||||
IPAddress string `json:"IPAddress"`
|
||||
InterfaceIndex int `json:"InterfaceIndex"`
|
||||
InterfaceAlias string `json:"InterfaceAlias"`
|
||||
AddressFamily int `json:"AddressFamily"`
|
||||
NextHop string `json:"NextHop"`
|
||||
DestinationPrefix string `json:"DestinationPrefix"`
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
destination string
|
||||
expectedSourceIP string
|
||||
expectedDestPrefix string
|
||||
expectedNextHop string
|
||||
expectedInterface string
|
||||
dialer dialer
|
||||
}
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
destination: "192.0.2.1:53",
|
||||
expectedSourceIP: "100.64.0.1",
|
||||
expectedDestPrefix: "128.0.0.0/1",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
destination: "192.0.2.1:53",
|
||||
expectedDestPrefix: "192.0.2.1/32",
|
||||
expectedInterface: expectedExtInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
destination: "10.0.0.2:53",
|
||||
expectedDestPrefix: "10.0.0.2/32",
|
||||
expectedInterface: expectedExtInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
destination: "10.0.0.2:53",
|
||||
expectedSourceIP: "10.0.0.1",
|
||||
expectedDestPrefix: "10.0.0.0/8",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
destination: "172.16.0.2:53",
|
||||
expectedDestPrefix: "172.16.0.2/32",
|
||||
expectedInterface: expectedExtInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without custom dialer via vpn",
|
||||
destination: "172.16.0.2:53",
|
||||
expectedSourceIP: "100.64.0.1",
|
||||
expectedDestPrefix: "172.16.0.0/12",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route without custom dialer via vpn interface",
|
||||
destination: "10.10.0.2:53",
|
||||
expectedSourceIP: "100.64.0.1",
|
||||
expectedDestPrefix: "10.10.0.0/24",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route (local) without custom dialer via physical interface",
|
||||
destination: "127.0.10.2:53",
|
||||
expectedSourceIP: "10.0.0.1",
|
||||
expectedDestPrefix: "127.0.0.0/8",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupTestEnv(t)
|
||||
|
||||
route, err := fetchOriginalGateway()
|
||||
require.NoError(t, err, "Failed to fetch original gateway")
|
||||
ip, err := fetchInterfaceIP(route.InterfaceAlias)
|
||||
require.NoError(t, err, "Failed to fetch interface IP")
|
||||
|
||||
output := testRoute(t, tc.destination, tc.dialer)
|
||||
if tc.expectedInterface == expectedExtInt {
|
||||
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
||||
} else {
|
||||
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
|
||||
func fetchInterfaceIP(interfaceAlias string) (string, error) {
|
||||
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
|
||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
|
||||
}
|
||||
|
||||
ip := strings.TrimSpace(string(out))
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "udp", destination)
|
||||
require.NoError(t, err, "Failed to dial destination")
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
assert.NoError(t, err, "Failed to close connection")
|
||||
}()
|
||||
|
||||
host, _, err := net.SplitHostPort(destination)
|
||||
require.NoError(t, err)
|
||||
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
|
||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||||
|
||||
var outputs []FindNetRouteOutput
|
||||
err = json.Unmarshal(out, &outputs)
|
||||
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
|
||||
|
||||
require.Greater(t, len(outputs), 0, "No route found for destination")
|
||||
combinedOutput := combineOutputs(outputs)
|
||||
|
||||
return combinedOutput
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
|
||||
require.NoError(t, err)
|
||||
subnetMaskSize, _ := ipNet.Mask.Size()
|
||||
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
|
||||
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
|
||||
|
||||
// Wait for the IP address to be applied
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
err = waitForIPAddress(ctx, interfaceName, ip.String())
|
||||
require.NoError(t, err, "IP address not applied within timeout")
|
||||
|
||||
t.Cleanup(func() {
|
||||
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
|
||||
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
|
||||
})
|
||||
|
||||
return interfaceName
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
||||
}
|
||||
|
||||
var routeInfo RouteInfo
|
||||
err = json.Unmarshal(output, &routeInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
|
||||
}
|
||||
|
||||
return &routeInfo, nil
|
||||
}
|
||||
|
||||
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
|
||||
t.Helper()
|
||||
|
||||
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
|
||||
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
|
||||
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
|
||||
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
|
||||
}
|
||||
|
||||
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
for _, ip := range ipAddresses {
|
||||
if strings.TrimSpace(ip) == expectedIPAddress {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
||||
var combined FindNetRouteOutput
|
||||
|
||||
for _, output := range outputs {
|
||||
if output.IPAddress != "" {
|
||||
combined.IPAddress = output.IPAddress
|
||||
}
|
||||
if output.InterfaceIndex != 0 {
|
||||
combined.InterfaceIndex = output.InterfaceIndex
|
||||
}
|
||||
if output.InterfaceAlias != "" {
|
||||
combined.InterfaceAlias = output.InterfaceAlias
|
||||
}
|
||||
if output.AddressFamily != 0 {
|
||||
combined.AddressFamily = output.AddressFamily
|
||||
}
|
||||
if output.NextHop != "" {
|
||||
combined.NextHop = output.NextHop
|
||||
}
|
||||
if output.DestinationPrefix != "" {
|
||||
combined.DestinationPrefix = output.DestinationPrefix
|
||||
}
|
||||
}
|
||||
|
||||
return &combined
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
|
||||
}
|
||||
132
client/internal/routeselector/routeselector.go
Normal file
132
client/internal/routeselector/routeselector.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package routeselector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
route "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
selectedRoutes map[string]struct{}
|
||||
selectAll bool
|
||||
}
|
||||
|
||||
func NewRouteSelector() *RouteSelector {
|
||||
return &RouteSelector{
|
||||
selectedRoutes: map[string]struct{}{},
|
||||
// default selects all routes
|
||||
selectAll: true,
|
||||
}
|
||||
}
|
||||
|
||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
|
||||
if !appendRoute {
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
|
||||
rs.selectedRoutes[route] = struct{}{}
|
||||
}
|
||||
rs.selectAll = false
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
func (rs *RouteSelector) SelectAllRoutes() {
|
||||
rs.selectAll = true
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
// DeselectRoutes removes specific routes from the selection.
|
||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
|
||||
if rs.selectAll {
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
for _, route := range allRoutes {
|
||||
rs.selectedRoutes[route] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
delete(rs.selectedRoutes, route)
|
||||
}
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID string) bool {
|
||||
if rs.selectAll {
|
||||
return true
|
||||
}
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
return selected
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
|
||||
if rs.selectAll {
|
||||
return maps.Clone(routes)
|
||||
}
|
||||
|
||||
filtered := map[string][]*route.Route{}
|
||||
for id, rt := range routes {
|
||||
netID := id
|
||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
||||
netID = id[:i]
|
||||
}
|
||||
if rs.IsSelected(netID) {
|
||||
filtered[id] = rt
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
275
client/internal/routeselector/routeselector_test.go
Normal file
275
client/internal/routeselector/routeselector_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package routeselector_test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestRouteSelector_SelectRoutes(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
selectRoutes []string
|
||||
append bool
|
||||
|
||||
wantSelected []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Select specific routes, initial all selected",
|
||||
selectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{"route1", "route2"},
|
||||
},
|
||||
{
|
||||
name: "Select specific routes, initial all deselected",
|
||||
initialSelected: []string{},
|
||||
selectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{"route1", "route2"},
|
||||
},
|
||||
{
|
||||
name: "Select specific routes with initial selection",
|
||||
initialSelected: []string{"route1"},
|
||||
selectRoutes: []string{"route2", "route3"},
|
||||
wantSelected: []string{"route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Select non-existing route",
|
||||
selectRoutes: []string{"route1", "route4"},
|
||||
wantSelected: []string{"route1"},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "Append route with initial selection",
|
||||
initialSelected: []string{"route1"},
|
||||
selectRoutes: []string{"route2"},
|
||||
append: true,
|
||||
wantSelected: []string{"route1", "route2"},
|
||||
},
|
||||
{
|
||||
name: "Append route without initial selection",
|
||||
selectRoutes: []string{"route2"},
|
||||
append: true,
|
||||
wantSelected: []string{"route2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
wantSelected []string
|
||||
}{
|
||||
{
|
||||
name: "Initial all selected",
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Initial all deselected",
|
||||
initialSelected: []string{},
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Initial some selected",
|
||||
initialSelected: []string{"route1"},
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Initial all selected",
|
||||
initialSelected: []string{"route1", "route2", "route3"},
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
rs.SelectAllRoutes()
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
deselectRoutes []string
|
||||
|
||||
wantSelected []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Deselect specific routes, initial all selected",
|
||||
deselectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{"route3"},
|
||||
},
|
||||
{
|
||||
name: "Deselect specific routes, initial all deselected",
|
||||
initialSelected: []string{},
|
||||
deselectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Deselect specific routes with initial selection",
|
||||
initialSelected: []string{"route1", "route2"},
|
||||
deselectRoutes: []string{"route1", "route3"},
|
||||
wantSelected: []string{"route2"},
|
||||
},
|
||||
{
|
||||
name: "Deselect non-existing route",
|
||||
initialSelected: []string{"route1", "route2"},
|
||||
deselectRoutes: []string{"route1", "route4"},
|
||||
wantSelected: []string{"route2"},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_DeselectAll(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
wantSelected []string
|
||||
}{
|
||||
{
|
||||
name: "Initial all selected",
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Initial all deselected",
|
||||
initialSelected: []string{},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Initial some selected",
|
||||
initialSelected: []string{"route1", "route2"},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Initial all selected",
|
||||
initialSelected: []string{"route1", "route2", "route3"},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
rs.DeselectAllRoutes()
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_IsSelected(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, rs.IsSelected("route1"))
|
||||
assert.True(t, rs.IsSelected("route2"))
|
||||
assert.False(t, rs.IsSelected("route3"))
|
||||
assert.False(t, rs.IsSelected("route4"))
|
||||
}
|
||||
|
||||
func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := map[string][]*route.Route{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route3-172.16.0.0/12": {},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelected(routes)
|
||||
|
||||
assert.Equal(t, map[string][]*route.Route{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
}, filtered)
|
||||
}
|
||||
@@ -1,10 +1,8 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) {
|
||||
}
|
||||
|
||||
func (pl portLookup) tryToBind(port int) error {
|
||||
l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port))
|
||||
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
@@ -29,7 +30,7 @@ type WGEBPFProxy struct {
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
rawConn net.PacketConn
|
||||
conn *net.UDPConn
|
||||
conn transport.UDPConn
|
||||
}
|
||||
|
||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||
@@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
|
||||
p.conn, err = nbnet.ListenUDP("udp", &addr)
|
||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
cErr := p.Free()
|
||||
if cErr != nil {
|
||||
@@ -75,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
@@ -228,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -71,6 +71,42 @@ func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// SetRosenpassEnabled store if rosenpass is enabled
|
||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||
p.configInput.RosenpassEnabled = &enabled
|
||||
}
|
||||
|
||||
// GetRosenpassEnabled read rosenpass enabled from config file
|
||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
if p.configInput.RosenpassEnabled != nil {
|
||||
return *p.configInput.RosenpassEnabled, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.RosenpassEnabled, err
|
||||
}
|
||||
|
||||
// SetRosenpassPermissive store the given permissive and wait for commit
|
||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||
p.configInput.RosenpassPermissive = &permissive
|
||||
}
|
||||
|
||||
// GetRosenpassPermissive read rosenpass permissive from config file
|
||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
if p.configInput.RosenpassPermissive != nil {
|
||||
return *p.configInput.RosenpassPermissive, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.RosenpassPermissive, err
|
||||
}
|
||||
|
||||
// Commit write out the changes into config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v4.24.3
|
||||
// protoc v3.12.4
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
|
||||
duration "github.com/golang/protobuf/ptypes/duration"
|
||||
timestamp "github.com/golang/protobuf/ptypes/timestamp"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
_ "google.golang.org/protobuf/types/descriptorpb"
|
||||
durationpb "google.golang.org/protobuf/types/known/durationpb"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
@@ -23,6 +23,70 @@ const (
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type LogLevel int32
|
||||
|
||||
const (
|
||||
LogLevel_UNKNOWN LogLevel = 0
|
||||
LogLevel_PANIC LogLevel = 1
|
||||
LogLevel_FATAL LogLevel = 2
|
||||
LogLevel_ERROR LogLevel = 3
|
||||
LogLevel_WARN LogLevel = 4
|
||||
LogLevel_INFO LogLevel = 5
|
||||
LogLevel_DEBUG LogLevel = 6
|
||||
LogLevel_TRACE LogLevel = 7
|
||||
)
|
||||
|
||||
// Enum value maps for LogLevel.
|
||||
var (
|
||||
LogLevel_name = map[int32]string{
|
||||
0: "UNKNOWN",
|
||||
1: "PANIC",
|
||||
2: "FATAL",
|
||||
3: "ERROR",
|
||||
4: "WARN",
|
||||
5: "INFO",
|
||||
6: "DEBUG",
|
||||
7: "TRACE",
|
||||
}
|
||||
LogLevel_value = map[string]int32{
|
||||
"UNKNOWN": 0,
|
||||
"PANIC": 1,
|
||||
"FATAL": 2,
|
||||
"ERROR": 3,
|
||||
"WARN": 4,
|
||||
"INFO": 5,
|
||||
"DEBUG": 6,
|
||||
"TRACE": 7,
|
||||
}
|
||||
)
|
||||
|
||||
func (x LogLevel) Enum() *LogLevel {
|
||||
p := new(LogLevel)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
|
||||
func (x LogLevel) String() string {
|
||||
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||
}
|
||||
|
||||
func (LogLevel) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_daemon_proto_enumTypes[0].Descriptor()
|
||||
}
|
||||
|
||||
func (LogLevel) Type() protoreflect.EnumType {
|
||||
return &file_daemon_proto_enumTypes[0]
|
||||
}
|
||||
|
||||
func (x LogLevel) Number() protoreflect.EnumNumber {
|
||||
return protoreflect.EnumNumber(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use LogLevel.Descriptor instead.
|
||||
func (LogLevel) EnumDescriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -766,23 +830,23 @@ type PeerState struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
|
||||
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
|
||||
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
|
||||
ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
|
||||
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
|
||||
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
|
||||
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
|
||||
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
|
||||
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
|
||||
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
|
||||
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
|
||||
LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
|
||||
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
|
||||
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
|
||||
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
|
||||
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
|
||||
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
|
||||
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
|
||||
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
|
||||
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
|
||||
ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
|
||||
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
|
||||
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
|
||||
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
|
||||
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
|
||||
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
|
||||
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
|
||||
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
|
||||
LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
|
||||
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
|
||||
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
|
||||
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
|
||||
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
|
||||
Latency *duration.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
|
||||
}
|
||||
|
||||
func (x *PeerState) Reset() {
|
||||
@@ -838,7 +902,7 @@ func (x *PeerState) GetConnStatus() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp {
|
||||
func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp {
|
||||
if x != nil {
|
||||
return x.ConnStatusUpdate
|
||||
}
|
||||
@@ -894,7 +958,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp {
|
||||
func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp {
|
||||
if x != nil {
|
||||
return x.LastWireguardHandshake
|
||||
}
|
||||
@@ -929,7 +993,7 @@ func (x *PeerState) GetRoutes() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *PeerState) GetLatency() *durationpb.Duration {
|
||||
func (x *PeerState) GetLatency() *duration.Duration {
|
||||
if x != nil {
|
||||
return x.Latency
|
||||
}
|
||||
@@ -1383,6 +1447,442 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState {
|
||||
return nil
|
||||
}
|
||||
|
||||
type ListRoutesRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *ListRoutesRequest) Reset() {
|
||||
*x = ListRoutesRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[19]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ListRoutesRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListRoutesRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[19]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ListRoutesRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{19}
|
||||
}
|
||||
|
||||
type ListRoutesResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"`
|
||||
}
|
||||
|
||||
func (x *ListRoutesResponse) Reset() {
|
||||
*x = ListRoutesResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[20]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ListRoutesResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListRoutesResponse) ProtoMessage() {}
|
||||
|
||||
func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[20]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead.
|
||||
func (*ListRoutesResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{20}
|
||||
}
|
||||
|
||||
func (x *ListRoutesResponse) GetRoutes() []*Route {
|
||||
if x != nil {
|
||||
return x.Routes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SelectRoutesRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"`
|
||||
Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"`
|
||||
All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"`
|
||||
}
|
||||
|
||||
func (x *SelectRoutesRequest) Reset() {
|
||||
*x = SelectRoutesRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[21]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SelectRoutesRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SelectRoutesRequest) ProtoMessage() {}
|
||||
|
||||
func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[21]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead.
|
||||
func (*SelectRoutesRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{21}
|
||||
}
|
||||
|
||||
func (x *SelectRoutesRequest) GetRouteIDs() []string {
|
||||
if x != nil {
|
||||
return x.RouteIDs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *SelectRoutesRequest) GetAppend() bool {
|
||||
if x != nil {
|
||||
return x.Append
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *SelectRoutesRequest) GetAll() bool {
|
||||
if x != nil {
|
||||
return x.All
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type SelectRoutesResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *SelectRoutesResponse) Reset() {
|
||||
*x = SelectRoutesResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[22]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SelectRoutesResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SelectRoutesResponse) ProtoMessage() {}
|
||||
|
||||
func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[22]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead.
|
||||
func (*SelectRoutesResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{22}
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"`
|
||||
Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"`
|
||||
Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Route) Reset() {
|
||||
*x = Route{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[23]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Route) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Route) ProtoMessage() {}
|
||||
|
||||
func (x *Route) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[23]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Route.ProtoReflect.Descriptor instead.
|
||||
func (*Route) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{23}
|
||||
}
|
||||
|
||||
func (x *Route) GetID() string {
|
||||
if x != nil {
|
||||
return x.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Route) GetNetwork() string {
|
||||
if x != nil {
|
||||
return x.Network
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Route) GetSelected() bool {
|
||||
if x != nil {
|
||||
return x.Selected
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type DebugBundleRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
|
||||
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) Reset() {
|
||||
*x = DebugBundleRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[24]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*DebugBundleRequest) ProtoMessage() {}
|
||||
|
||||
func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[24]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead.
|
||||
func (*DebugBundleRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{24}
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) GetAnonymize() bool {
|
||||
if x != nil {
|
||||
return x.Anonymize
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) GetStatus() string {
|
||||
if x != nil {
|
||||
return x.Status
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type DebugBundleResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
|
||||
}
|
||||
|
||||
func (x *DebugBundleResponse) Reset() {
|
||||
*x = DebugBundleResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[25]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *DebugBundleResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*DebugBundleResponse) ProtoMessage() {}
|
||||
|
||||
func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[25]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead.
|
||||
func (*DebugBundleResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{25}
|
||||
}
|
||||
|
||||
func (x *DebugBundleResponse) GetPath() string {
|
||||
if x != nil {
|
||||
return x.Path
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type SetLogLevelRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Level LogLevel `protobuf:"varint,1,opt,name=level,proto3,enum=daemon.LogLevel" json:"level,omitempty"`
|
||||
}
|
||||
|
||||
func (x *SetLogLevelRequest) Reset() {
|
||||
*x = SetLogLevelRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[26]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SetLogLevelRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SetLogLevelRequest) ProtoMessage() {}
|
||||
|
||||
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[26]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
|
||||
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{26}
|
||||
}
|
||||
|
||||
func (x *SetLogLevelRequest) GetLevel() LogLevel {
|
||||
if x != nil {
|
||||
return x.Level
|
||||
}
|
||||
return LogLevel_UNKNOWN
|
||||
}
|
||||
|
||||
type SetLogLevelResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *SetLogLevelResponse) Reset() {
|
||||
*x = SetLogLevelResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[27]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SetLogLevelResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SetLogLevelResponse) ProtoMessage() {}
|
||||
|
||||
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[27]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
|
||||
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{27}
|
||||
}
|
||||
|
||||
var File_daemon_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_daemon_proto_rawDesc = []byte{
|
||||
@@ -1601,32 +2101,92 @@ var file_daemon_proto_rawDesc = []byte{
|
||||
0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74,
|
||||
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02,
|
||||
0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
|
||||
0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15,
|
||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73,
|
||||
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53,
|
||||
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
||||
0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61,
|
||||
0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
|
||||
0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
|
||||
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33,
|
||||
0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
|
||||
0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
|
||||
0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
|
||||
0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73,
|
||||
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a,
|
||||
0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74,
|
||||
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22,
|
||||
0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
|
||||
0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
|
||||
0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01,
|
||||
0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c,
|
||||
0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14,
|
||||
0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
|
||||
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4d, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
|
||||
0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
|
||||
0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
|
||||
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
|
||||
0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
|
||||
0x74, 0x65, 0x64, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
|
||||
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
|
||||
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
|
||||
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
|
||||
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
|
||||
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
|
||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65,
|
||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
|
||||
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
|
||||
0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a,
|
||||
0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
|
||||
0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49,
|
||||
0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52,
|
||||
0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a,
|
||||
0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43,
|
||||
0x45, 0x10, 0x07, 0x32, 0xee, 0x05, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14,
|
||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
|
||||
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
|
||||
0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
|
||||
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70,
|
||||
0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61,
|
||||
0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
|
||||
0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
|
||||
0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a,
|
||||
0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
|
||||
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
||||
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
|
||||
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
||||
0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
|
||||
0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c,
|
||||
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
|
||||
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
|
||||
0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12,
|
||||
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
|
||||
0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
|
||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
||||
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -1641,58 +2201,81 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
return file_daemon_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19)
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 28)
|
||||
var file_daemon_proto_goTypes = []interface{}{
|
||||
(*LoginRequest)(nil), // 0: daemon.LoginRequest
|
||||
(*LoginResponse)(nil), // 1: daemon.LoginResponse
|
||||
(*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest
|
||||
(*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse
|
||||
(*UpRequest)(nil), // 4: daemon.UpRequest
|
||||
(*UpResponse)(nil), // 5: daemon.UpResponse
|
||||
(*StatusRequest)(nil), // 6: daemon.StatusRequest
|
||||
(*StatusResponse)(nil), // 7: daemon.StatusResponse
|
||||
(*DownRequest)(nil), // 8: daemon.DownRequest
|
||||
(*DownResponse)(nil), // 9: daemon.DownResponse
|
||||
(*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest
|
||||
(*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse
|
||||
(*PeerState)(nil), // 12: daemon.PeerState
|
||||
(*LocalPeerState)(nil), // 13: daemon.LocalPeerState
|
||||
(*SignalState)(nil), // 14: daemon.SignalState
|
||||
(*ManagementState)(nil), // 15: daemon.ManagementState
|
||||
(*RelayState)(nil), // 16: daemon.RelayState
|
||||
(*NSGroupState)(nil), // 17: daemon.NSGroupState
|
||||
(*FullStatus)(nil), // 18: daemon.FullStatus
|
||||
(*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp
|
||||
(*durationpb.Duration)(nil), // 20: google.protobuf.Duration
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(*LoginRequest)(nil), // 1: daemon.LoginRequest
|
||||
(*LoginResponse)(nil), // 2: daemon.LoginResponse
|
||||
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
|
||||
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
|
||||
(*UpRequest)(nil), // 5: daemon.UpRequest
|
||||
(*UpResponse)(nil), // 6: daemon.UpResponse
|
||||
(*StatusRequest)(nil), // 7: daemon.StatusRequest
|
||||
(*StatusResponse)(nil), // 8: daemon.StatusResponse
|
||||
(*DownRequest)(nil), // 9: daemon.DownRequest
|
||||
(*DownResponse)(nil), // 10: daemon.DownResponse
|
||||
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
|
||||
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
|
||||
(*PeerState)(nil), // 13: daemon.PeerState
|
||||
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
|
||||
(*SignalState)(nil), // 15: daemon.SignalState
|
||||
(*ManagementState)(nil), // 16: daemon.ManagementState
|
||||
(*RelayState)(nil), // 17: daemon.RelayState
|
||||
(*NSGroupState)(nil), // 18: daemon.NSGroupState
|
||||
(*FullStatus)(nil), // 19: daemon.FullStatus
|
||||
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
|
||||
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
|
||||
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
|
||||
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
|
||||
(*Route)(nil), // 24: daemon.Route
|
||||
(*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest
|
||||
(*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse
|
||||
(*SetLogLevelRequest)(nil), // 27: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 28: daemon.SetLogLevelResponse
|
||||
(*timestamp.Timestamp)(nil), // 29: google.protobuf.Timestamp
|
||||
(*duration.Duration)(nil), // 30: google.protobuf.Duration
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
20, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
15, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
14, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
13, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
12, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
16, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
17, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
0, // 10: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
2, // 11: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
4, // 12: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
6, // 13: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
8, // 14: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
10, // 15: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
1, // 16: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
3, // 17: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
5, // 18: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
7, // 19: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
9, // 20: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
11, // 21: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
16, // [16:22] is the sub-list for method output_type
|
||||
10, // [10:16] is the sub-list for method input_type
|
||||
10, // [10:10] is the sub-list for extension type_name
|
||||
10, // [10:10] is the sub-list for extension extendee
|
||||
0, // [0:10] is the sub-list for field type_name
|
||||
19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
29, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
29, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
30, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
13, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
|
||||
17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
|
||||
0, // 11: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
1, // 12: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
3, // 13: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
5, // 14: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
7, // 15: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
9, // 16: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
11, // 17: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
20, // 18: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
|
||||
22, // 19: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
22, // 20: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
25, // 21: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
27, // 22: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
2, // 23: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
4, // 24: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
6, // 25: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
8, // 26: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
10, // 27: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
12, // 28: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
21, // 29: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
|
||||
23, // 30: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
23, // 31: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
26, // 32: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
28, // 33: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
23, // [23:34] is the sub-list for method output_type
|
||||
12, // [12:23] is the sub-list for method input_type
|
||||
12, // [12:12] is the sub-list for extension type_name
|
||||
12, // [12:12] is the sub-list for extension extendee
|
||||
0, // [0:12] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_daemon_proto_init() }
|
||||
@@ -1929,6 +2512,114 @@ func file_daemon_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ListRoutesRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ListRoutesResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SelectRoutesRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SelectRoutesResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Route); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*DebugBundleRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*DebugBundleResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetLogLevelRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetLogLevelResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
|
||||
type x struct{}
|
||||
@@ -1936,13 +2627,14 @@ func file_daemon_proto_init() {
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_daemon_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 19,
|
||||
NumEnums: 1,
|
||||
NumMessages: 28,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_daemon_proto_goTypes,
|
||||
DependencyIndexes: file_daemon_proto_depIdxs,
|
||||
EnumInfos: file_daemon_proto_enumTypes,
|
||||
MessageInfos: file_daemon_proto_msgTypes,
|
||||
}.Build()
|
||||
File_daemon_proto = out.File
|
||||
|
||||
@@ -27,6 +27,21 @@ service DaemonService {
|
||||
|
||||
// GetConfig of the daemon.
|
||||
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {}
|
||||
|
||||
// List available network routes
|
||||
rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {}
|
||||
|
||||
// Select specific routes
|
||||
rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
|
||||
|
||||
// Deselect specific routes
|
||||
rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
|
||||
|
||||
// DebugBundle creates a debug bundle
|
||||
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
|
||||
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
|
||||
};
|
||||
|
||||
message LoginRequest {
|
||||
@@ -195,4 +210,53 @@ message FullStatus {
|
||||
repeated PeerState peers = 4;
|
||||
repeated RelayState relays = 5;
|
||||
repeated NSGroupState dns_servers = 6;
|
||||
}
|
||||
|
||||
message ListRoutesRequest {
|
||||
}
|
||||
|
||||
message ListRoutesResponse {
|
||||
repeated Route routes = 1;
|
||||
}
|
||||
|
||||
message SelectRoutesRequest {
|
||||
repeated string routeIDs = 1;
|
||||
bool append = 2;
|
||||
bool all = 3;
|
||||
}
|
||||
|
||||
message SelectRoutesResponse {
|
||||
}
|
||||
|
||||
message Route {
|
||||
string ID = 1;
|
||||
string network = 2;
|
||||
bool selected = 3;
|
||||
}
|
||||
|
||||
message DebugBundleRequest {
|
||||
bool anonymize = 1;
|
||||
string status = 2;
|
||||
}
|
||||
|
||||
message DebugBundleResponse {
|
||||
string path = 1;
|
||||
}
|
||||
|
||||
enum LogLevel {
|
||||
UNKNOWN = 0;
|
||||
PANIC = 1;
|
||||
FATAL = 2;
|
||||
ERROR = 3;
|
||||
WARN = 4;
|
||||
INFO = 5;
|
||||
DEBUG = 6;
|
||||
TRACE = 7;
|
||||
}
|
||||
|
||||
message SetLogLevelRequest {
|
||||
LogLevel level = 1;
|
||||
}
|
||||
|
||||
message SetLogLevelResponse {
|
||||
}
|
||||
@@ -31,6 +31,16 @@ type DaemonServiceClient interface {
|
||||
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
|
||||
// GetConfig of the daemon.
|
||||
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
|
||||
// List available network routes
|
||||
ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error)
|
||||
// Select specific routes
|
||||
SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
|
||||
// Deselect specific routes
|
||||
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
|
||||
// DebugBundle creates a debug bundle
|
||||
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -95,6 +105,51 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) {
|
||||
out := new(ListRoutesResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
|
||||
out := new(SelectRoutesResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
|
||||
out := new(SelectRoutesResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
|
||||
out := new(DebugBundleResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
|
||||
out := new(SetLogLevelResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -112,6 +167,16 @@ type DaemonServiceServer interface {
|
||||
Down(context.Context, *DownRequest) (*DownResponse, error)
|
||||
// GetConfig of the daemon.
|
||||
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
|
||||
// List available network routes
|
||||
ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error)
|
||||
// Select specific routes
|
||||
SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
|
||||
// Deselect specific routes
|
||||
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
|
||||
// DebugBundle creates a debug bundle
|
||||
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -137,6 +202,21 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do
|
||||
func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -258,6 +338,96 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ListRoutesRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).ListRoutes(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/ListRoutes",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SelectRoutesRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).SelectRoutes(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SelectRoutes",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SelectRoutesRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).DeselectRoutes(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/DeselectRoutes",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(DebugBundleRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).DebugBundle(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/DebugBundle",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SetLogLevelRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).SetLogLevel(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SetLogLevel",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -289,6 +459,26 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetConfig",
|
||||
Handler: _DaemonService_GetConfig_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ListRoutes",
|
||||
Handler: _DaemonService_ListRoutes_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SelectRoutes",
|
||||
Handler: _DaemonService_SelectRoutes_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "DeselectRoutes",
|
||||
Handler: _DaemonService_DeselectRoutes_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "DebugBundle",
|
||||
Handler: _DaemonService_DebugBundle_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SetLogLevel",
|
||||
Handler: _DaemonService_SetLogLevel_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "daemon.proto",
|
||||
|
||||
175
client/server/debug.go
Normal file
175
client/server/debug.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.logFile == "console" {
|
||||
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
|
||||
}
|
||||
|
||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create zip file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := bundlePath.Close(); err != nil {
|
||||
log.Errorf("failed to close zip file: %v", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
|
||||
log.Errorf("Failed to remove zip file: %v", err2)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
archive := zip.NewWriter(bundlePath)
|
||||
defer func() {
|
||||
if err := archive.Close(); err != nil {
|
||||
log.Errorf("failed to close archive writer: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if status := req.GetStatus(); status != "" {
|
||||
filename := "status.txt"
|
||||
if req.GetAnonymize() {
|
||||
filename = "status.anon.txt"
|
||||
}
|
||||
statusReader := strings.NewReader(status)
|
||||
if err := addFileToZip(archive, statusReader, filename); err != nil {
|
||||
return nil, fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logFile, err := os.Open(s.logFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("failed to close original log file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
filename := "client.log.txt"
|
||||
var logReader io.Reader
|
||||
errChan := make(chan error, 1)
|
||||
if req.GetAnonymize() {
|
||||
filename = "client.anon.log.txt"
|
||||
var writer io.WriteCloser
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go s.anonymize(logFile, writer, errChan)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
if err := addFileToZip(archive, logReader, filename); err != nil {
|
||||
return nil, fmt.Errorf("add log file to zip: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
|
||||
}
|
||||
|
||||
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
status := s.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(anonymizer, &status)
|
||||
|
||||
defer func() {
|
||||
if err := writer.Close(); err != nil {
|
||||
log.Errorf("Failed to close writer: %v", err)
|
||||
}
|
||||
}()
|
||||
for scanner.Scan() {
|
||||
line := anonymizer.AnonymizeString(scanner.Text())
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
errChan <- fmt.Errorf("write line to writer: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
errChan <- fmt.Errorf("read line from scanner: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the server.
|
||||
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
|
||||
level, err := log.ParseLevel(req.Level.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid log level: %w", err)
|
||||
}
|
||||
|
||||
log.SetLevel(level)
|
||||
log.Infof("Log level set to %s", level.String())
|
||||
return &proto.SetLogLevelResponse{}, nil
|
||||
}
|
||||
|
||||
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
Method: zip.Deflate,
|
||||
}
|
||||
|
||||
writer, err := archive.CreateHeader(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create zip file header: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(writer, reader); err != nil {
|
||||
return fmt.Errorf("write file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
|
||||
status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
|
||||
status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
|
||||
|
||||
status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
|
||||
|
||||
for _, peer := range status.Peers {
|
||||
a.AnonymizeDomain(peer.FQDN)
|
||||
}
|
||||
|
||||
for _, nsGroup := range status.NSGroupStates {
|
||||
for _, domain := range nsGroup.Domains {
|
||||
a.AnonymizeDomain(domain)
|
||||
}
|
||||
}
|
||||
|
||||
for _, relay := range status.Relays {
|
||||
if relay.URI != nil {
|
||||
a.AnonymizeURI(relay.URI.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
110
client/server/route.go
Normal file
110
client/server/route.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
NetID string
|
||||
Network netip.Prefix
|
||||
Selected bool
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of all available routes.
|
||||
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.engine == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
routesMap := s.engine.GetClientRoutesWithNetID()
|
||||
routeSelector := s.engine.GetRouteManager().GetRouteSelector()
|
||||
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
if len(rt) == 0 {
|
||||
continue
|
||||
}
|
||||
route := &selectRoute{
|
||||
NetID: id,
|
||||
Network: rt[0].Network,
|
||||
Selected: routeSelector.IsSelected(id),
|
||||
}
|
||||
routes = append(routes, route)
|
||||
}
|
||||
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
iPrefix := routes[i].Network.Bits()
|
||||
jPrefix := routes[j].Network.Bits()
|
||||
|
||||
if iPrefix == jPrefix {
|
||||
iAddr := routes[i].Network.Addr()
|
||||
jAddr := routes[j].Network.Addr()
|
||||
if iAddr == jAddr {
|
||||
return routes[i].NetID < routes[j].NetID
|
||||
}
|
||||
return iAddr.String() < jAddr.String()
|
||||
}
|
||||
return iPrefix < jPrefix
|
||||
})
|
||||
|
||||
var pbRoutes []*proto.Route
|
||||
for _, route := range routes {
|
||||
pbRoutes = append(pbRoutes, &proto.Route{
|
||||
ID: route.NetID,
|
||||
Network: route.Network.String(),
|
||||
Selected: route.Selected,
|
||||
})
|
||||
}
|
||||
|
||||
return &proto.ListRoutesResponse{
|
||||
Routes: pbRoutes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SelectRoutes selects specific routes based on the client request.
|
||||
func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
routeManager := s.engine.GetRouteManager()
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if req.GetAll() {
|
||||
routeSelector.SelectAllRoutes()
|
||||
} else {
|
||||
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
||||
return nil, fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(s.engine.GetClientRoutes())
|
||||
|
||||
return &proto.SelectRoutesResponse{}, nil
|
||||
}
|
||||
|
||||
// DeselectRoutes deselects specific routes based on the client request.
|
||||
func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
routeManager := s.engine.GetRouteManager()
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if req.GetAll() {
|
||||
routeSelector.DeselectAllRoutes()
|
||||
} else {
|
||||
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(s.engine.GetClientRoutes())
|
||||
|
||||
return &proto.SelectRoutesResponse{}, nil
|
||||
}
|
||||
@@ -15,15 +15,15 @@ import (
|
||||
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -57,6 +57,8 @@ type Server struct {
|
||||
config *internal.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
|
||||
engine *internal.Engine
|
||||
|
||||
statusRecorder *peer.Status
|
||||
sessionWatcher *internal.SessionWatcher
|
||||
|
||||
@@ -141,8 +143,11 @@ func (s *Server) Start() error {
|
||||
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
|
||||
}
|
||||
|
||||
engineChan := make(chan *internal.Engine, 1)
|
||||
go s.watchEngine(ctx, engineChan)
|
||||
|
||||
if !config.DisableAutoConnect {
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -153,6 +158,7 @@ func (s *Server) Start() error {
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
|
||||
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
|
||||
engineChan chan<- *internal.Engine,
|
||||
) {
|
||||
backOff := getConnectWithBackoff(ctx)
|
||||
retryStarted := false
|
||||
@@ -182,7 +188,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
|
||||
|
||||
runOperation := func() error {
|
||||
log.Tracef("running client connection")
|
||||
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
|
||||
if err != nil {
|
||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||
}
|
||||
@@ -562,7 +568,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
|
||||
engineChan := make(chan *internal.Engine, 1)
|
||||
go s.watchEngine(ctx, engineChan)
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
|
||||
|
||||
return &proto.UpResponse{}, nil
|
||||
}
|
||||
@@ -579,6 +588,8 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
s.engine = nil
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -661,7 +672,6 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
||||
PreSharedKey: preSharedKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) onSessionExpire() {
|
||||
if runtime.GOOS != "windows" {
|
||||
isUIActive := internal.CheckUIApp()
|
||||
@@ -673,6 +683,22 @@ func (s *Server) onSessionExpire() {
|
||||
}
|
||||
}
|
||||
|
||||
// watchEngine watches the engine channel and updates the engine state
|
||||
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
|
||||
log.Tracef("Started watching engine")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.engine = nil
|
||||
log.Tracef("Stopped watching engine")
|
||||
return
|
||||
case engine := <-engineChan:
|
||||
log.Tracef("Received engine from watcher")
|
||||
s.engine = engine
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
@@ -718,7 +744,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
BytesRx: peerState.BytesRx,
|
||||
BytesTx: peerState.BytesTx,
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Routes: maps.Keys(peerState.Routes),
|
||||
Routes: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
|
||||
@@ -2,11 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -69,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
|
||||
@@ -25,8 +25,6 @@ func Detect(ctx context.Context) string {
|
||||
detectDigitalOcean,
|
||||
detectGCP,
|
||||
detectOracle,
|
||||
detectIBMCloud,
|
||||
detectSoftlayer,
|
||||
detectVultr,
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func detectGCP(ctx context.Context) string {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
package detect_cloud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func detectIBMCloud(ctx context.Context) string {
|
||||
v1ResultChan := make(chan bool, 1)
|
||||
v2ResultChan := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
v1ResultChan <- detectIBMSecure(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
v2ResultChan <- detectIBM(ctx)
|
||||
}()
|
||||
|
||||
v1Result, v2Result := <-v1ResultChan, <-v2ResultChan
|
||||
|
||||
if v1Result || v2Result {
|
||||
return "IBM Cloud"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func detectIBMSecure(ctx context.Context) bool {
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
func detectIBM(ctx context.Context) bool {
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package detect_cloud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func detectSoftlayer(ctx context.Context) string {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
// Since SoftLayer was acquired by IBM, we should return "IBM Cloud"
|
||||
return "IBM Cloud"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
//go:build !(linux && 386)
|
||||
// +build !linux !386
|
||||
|
||||
// skipping linux 32 bits build and tests
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -58,14 +56,23 @@ func main() {
|
||||
|
||||
var showSettings bool
|
||||
flag.BoolVar(&showSettings, "settings", false, "run settings windows")
|
||||
var showRoutes bool
|
||||
flag.BoolVar(&showRoutes, "routes", false, "run routes windows")
|
||||
var errorMSG string
|
||||
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
a := app.NewWithID("NetBird")
|
||||
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
|
||||
|
||||
client := newServiceClient(daemonAddr, a, showSettings)
|
||||
if showSettings {
|
||||
if errorMSG != "" {
|
||||
showErrorMSG(errorMSG)
|
||||
return
|
||||
}
|
||||
|
||||
client := newServiceClient(daemonAddr, a, showSettings, showRoutes)
|
||||
if showSettings || showRoutes {
|
||||
a.Run()
|
||||
} else {
|
||||
if err := checkPIDFile(); err != nil {
|
||||
@@ -128,6 +135,7 @@ type serviceClient struct {
|
||||
mVersionDaemon *systray.MenuItem
|
||||
mUpdate *systray.MenuItem
|
||||
mQuit *systray.MenuItem
|
||||
mRoutes *systray.MenuItem
|
||||
|
||||
// application with main windows.
|
||||
app fyne.App
|
||||
@@ -152,12 +160,15 @@ type serviceClient struct {
|
||||
daemonVersion string
|
||||
updateIndicationLock sync.Mutex
|
||||
isUpdateIconActive bool
|
||||
|
||||
showRoutes bool
|
||||
wRoutes fyne.Window
|
||||
}
|
||||
|
||||
// newServiceClient instance constructor
|
||||
//
|
||||
// This constructor also builds the UI elements for the settings window.
|
||||
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
|
||||
func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient {
|
||||
s := &serviceClient{
|
||||
ctx: context.Background(),
|
||||
addr: addr,
|
||||
@@ -165,6 +176,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
|
||||
sendNotification: false,
|
||||
|
||||
showSettings: showSettings,
|
||||
showRoutes: showRoutes,
|
||||
update: version.NewUpdate(),
|
||||
}
|
||||
|
||||
@@ -184,14 +196,16 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
|
||||
}
|
||||
|
||||
if showSettings {
|
||||
s.showUIElements()
|
||||
s.showSettingsUI()
|
||||
return s
|
||||
} else if showRoutes {
|
||||
s.showRoutesUI()
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceClient) showUIElements() {
|
||||
func (s *serviceClient) showSettingsUI() {
|
||||
// add settings window UI elements.
|
||||
s.wSettings = s.app.NewWindow("NetBird Settings")
|
||||
s.iMngURL = widget.NewEntry()
|
||||
@@ -209,6 +223,18 @@ func (s *serviceClient) showUIElements() {
|
||||
s.wSettings.Show()
|
||||
}
|
||||
|
||||
// showErrorMSG opens a fyne app window to display the supplied message
|
||||
func showErrorMSG(msg string) {
|
||||
app := app.New()
|
||||
w := app.NewWindow("NetBird Error")
|
||||
content := widget.NewLabel(msg)
|
||||
content.Wrapping = fyne.TextWrapWord
|
||||
w.SetContent(content)
|
||||
w.Resize(fyne.NewSize(400, 100))
|
||||
w.Show()
|
||||
app.Run()
|
||||
}
|
||||
|
||||
// getSettingsForm to embed it into settings window.
|
||||
func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
return &widget.Form{
|
||||
@@ -397,6 +423,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
s.mStatus.SetTitle("Connected")
|
||||
s.mUp.Disable()
|
||||
s.mDown.Enable()
|
||||
s.mRoutes.Enable()
|
||||
systrayIconState = true
|
||||
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
|
||||
s.connected = false
|
||||
@@ -409,6 +436,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
s.mStatus.SetTitle("Disconnected")
|
||||
s.mDown.Disable()
|
||||
s.mUp.Enable()
|
||||
s.mRoutes.Disable()
|
||||
systrayIconState = false
|
||||
}
|
||||
|
||||
@@ -464,9 +492,11 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mUp = systray.AddMenuItem("Connect", "Connect")
|
||||
s.mDown = systray.AddMenuItem("Disconnect", "Disconnect")
|
||||
s.mDown.Disable()
|
||||
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Wiretrustee Admin Panel")
|
||||
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel")
|
||||
systray.AddSeparator()
|
||||
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
|
||||
s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window")
|
||||
s.mRoutes.Disable()
|
||||
systray.AddSeparator()
|
||||
|
||||
s.mAbout = systray.AddMenuItem("About", "About")
|
||||
@@ -504,16 +534,22 @@ func (s *serviceClient) onTrayReady() {
|
||||
case <-s.mAdminPanel.ClickedCh:
|
||||
err = open.Run(s.adminURL)
|
||||
case <-s.mUp.ClickedCh:
|
||||
s.mUp.Disabled()
|
||||
go func() {
|
||||
defer s.mUp.Enable()
|
||||
err := s.menuUpClick()
|
||||
if err != nil {
|
||||
s.runSelfCommand("error-msg", err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mDown.ClickedCh:
|
||||
s.mDown.Disable()
|
||||
go func() {
|
||||
defer s.mDown.Enable()
|
||||
err := s.menuDownClick()
|
||||
if err != nil {
|
||||
s.runSelfCommand("error-msg", err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -521,24 +557,8 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mSettings.Disable()
|
||||
go func() {
|
||||
defer s.mSettings.Enable()
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("show settings: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command(proc, "--settings=true")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
|
||||
log.Errorf("start settings UI: %v, %s", err, string(out))
|
||||
return
|
||||
}
|
||||
if len(out) != 0 {
|
||||
log.Info("settings change:", string(out))
|
||||
}
|
||||
|
||||
// update config in systray when settings windows closed
|
||||
s.getSrvConfig()
|
||||
defer s.getSrvConfig()
|
||||
s.runSelfCommand("settings", "true")
|
||||
}()
|
||||
case <-s.mQuit.ClickedCh:
|
||||
systray.Quit()
|
||||
@@ -548,6 +568,12 @@ func (s *serviceClient) onTrayReady() {
|
||||
if err != nil {
|
||||
log.Errorf("%s", err)
|
||||
}
|
||||
case <-s.mRoutes.ClickedCh:
|
||||
s.mRoutes.Disable()
|
||||
go func() {
|
||||
defer s.mRoutes.Enable()
|
||||
s.runSelfCommand("routes", "true")
|
||||
}()
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("process connection: %v", err)
|
||||
@@ -556,6 +582,24 @@ func (s *serviceClient) onTrayReady() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *serviceClient) runSelfCommand(command, arg string) {
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("show %s failed with error: %v", command, err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg))
|
||||
out, err := cmd.CombinedOutput()
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
|
||||
log.Errorf("start %s UI: %v, %s", command, err, string(out))
|
||||
return
|
||||
}
|
||||
if len(out) != 0 {
|
||||
log.Infof("command %s executed: %s", command, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func normalizedVersion(version string) string {
|
||||
versionString := version
|
||||
if unicode.IsDigit(rune(versionString[0])) {
|
||||
|
||||
203
client/ui/route.go
Normal file
203
client/ui/route.go
Normal file
@@ -0,0 +1,203 @@
|
||||
//go:build !(linux && 386)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/dialog"
|
||||
"fyne.io/fyne/v2/layout"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func (s *serviceClient) showRoutesUI() {
|
||||
s.wRoutes = s.app.NewWindow("NetBird Routes")
|
||||
|
||||
grid := container.New(layout.NewGridLayout(2))
|
||||
go s.updateRoutes(grid)
|
||||
routeCheckContainer := container.NewVBox()
|
||||
routeCheckContainer.Add(grid)
|
||||
scrollContainer := container.NewVScroll(routeCheckContainer)
|
||||
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
|
||||
|
||||
buttonBox := container.NewHBox(
|
||||
layout.NewSpacer(),
|
||||
widget.NewButton("Refresh", func() {
|
||||
s.updateRoutes(grid)
|
||||
}),
|
||||
widget.NewButton("Select all", func() {
|
||||
s.selectAllRoutes()
|
||||
s.updateRoutes(grid)
|
||||
}),
|
||||
widget.NewButton("Deselect All", func() {
|
||||
s.deselectAllRoutes()
|
||||
s.updateRoutes(grid)
|
||||
}),
|
||||
layout.NewSpacer(),
|
||||
)
|
||||
|
||||
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
|
||||
|
||||
s.wRoutes.SetContent(content)
|
||||
s.wRoutes.Show()
|
||||
|
||||
s.startAutoRefresh(5*time.Second, grid)
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
routes, err := s.fetchRoutes()
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
s.showError(fmt.Errorf("get client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
grid.Objects = nil
|
||||
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
|
||||
grid.Add(idHeader)
|
||||
grid.Add(networkHeader)
|
||||
for _, route := range routes {
|
||||
r := route
|
||||
|
||||
checkBox := widget.NewCheck(r.ID, func(checked bool) {
|
||||
s.selectRoute(r.ID, checked)
|
||||
})
|
||||
checkBox.Checked = route.Selected
|
||||
checkBox.Resize(fyne.NewSize(20, 20))
|
||||
checkBox.Refresh()
|
||||
|
||||
grid.Add(checkBox)
|
||||
grid.Add(widget.NewLabel(r.Network))
|
||||
}
|
||||
|
||||
s.wRoutes.Content().Refresh()
|
||||
}
|
||||
|
||||
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
|
||||
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list routes: %v", err)
|
||||
}
|
||||
|
||||
return resp.Routes, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) selectRoute(id string, checked bool) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
s.showError(fmt.Errorf("get client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
RouteIDs: []string{id},
|
||||
Append: checked,
|
||||
}
|
||||
|
||||
if checked {
|
||||
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to select route: %v", err)
|
||||
s.showError(fmt.Errorf("failed to select route: %v", err))
|
||||
return
|
||||
}
|
||||
log.Infof("Route %s selected", id)
|
||||
} else {
|
||||
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to deselect route: %v", err)
|
||||
s.showError(fmt.Errorf("failed to deselect route: %v", err))
|
||||
return
|
||||
}
|
||||
log.Infof("Route %s deselected", id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) selectAllRoutes() {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
All: true,
|
||||
}
|
||||
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to select all routes: %v", err)
|
||||
s.showError(fmt.Errorf("failed to select all routes: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("All routes selected")
|
||||
}
|
||||
|
||||
func (s *serviceClient) deselectAllRoutes() {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
All: true,
|
||||
}
|
||||
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to deselect all routes: %v", err)
|
||||
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("All routes deselected")
|
||||
}
|
||||
|
||||
func (s *serviceClient) showError(err error) {
|
||||
wrappedMessage := wrapText(err.Error(), 50)
|
||||
|
||||
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
|
||||
}
|
||||
|
||||
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
|
||||
ticker := time.NewTicker(interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
s.updateRoutes(grid)
|
||||
}
|
||||
}()
|
||||
|
||||
s.wRoutes.SetOnClosed(func() {
|
||||
ticker.Stop()
|
||||
})
|
||||
}
|
||||
|
||||
// wrapText inserts newlines into the text to ensure that each line is
|
||||
// no longer than 'lineLength' runes.
|
||||
func wrapText(text string, lineLength int) string {
|
||||
var sb strings.Builder
|
||||
var currentLineLength int
|
||||
|
||||
for _, runeValue := range text {
|
||||
sb.WriteRune(runeValue)
|
||||
currentLineLength++
|
||||
|
||||
if currentLineLength >= lineLength || runeValue == '\n' {
|
||||
sb.WriteRune('\n')
|
||||
currentLineLength = 0
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
12
go.mod
12
go.mod
@@ -21,8 +21,8 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
|
||||
golang.org/x/crypto v0.18.0
|
||||
golang.org/x/sys v0.16.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/sys v0.18.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
@@ -53,14 +53,14 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/libp2p/go-netroute v0.2.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/magiconair/properties v1.8.5
|
||||
github.com/mattn/go-sqlite3 v1.14.19
|
||||
github.com/mdlayher/socket v0.4.1
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -82,10 +82,10 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
|
||||
golang.org/x/net v0.20.0
|
||||
golang.org/x/net v0.23.0
|
||||
golang.org/x/oauth2 v0.8.0
|
||||
golang.org/x/sync v0.3.0
|
||||
golang.org/x/term v0.16.0
|
||||
golang.org/x/term v0.18.0
|
||||
google.golang.org/api v0.126.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/sqlite v1.5.3
|
||||
|
||||
22
go.sum
22
go.sum
@@ -345,8 +345,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE=
|
||||
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
|
||||
github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU=
|
||||
github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ=
|
||||
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
|
||||
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
|
||||
github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
|
||||
@@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
||||
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||
@@ -581,8 +581,9 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -659,7 +660,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
@@ -670,8 +670,9 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -746,7 +747,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -763,16 +763,18 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
||||
@@ -10,8 +10,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type wgKernelConfigurer struct {
|
||||
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fwmark := nbnet.NetbirdFwmark
|
||||
fwmark := getFwmark()
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ReplacePeers: true,
|
||||
|
||||
@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if runtime.GOOS == "linux" {
|
||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
||||
return nbnet.NetbirdFwmark
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -58,6 +58,7 @@ services:
|
||||
command: [
|
||||
"--port", "443",
|
||||
"--log-file", "console",
|
||||
"--log-level", "info",
|
||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
|
||||
@@ -2,7 +2,7 @@ version: "3"
|
||||
services:
|
||||
#UI dashboard
|
||||
dashboard:
|
||||
image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||
restart: unless-stopped
|
||||
#ports:
|
||||
# - 80:80
|
||||
|
||||
@@ -251,7 +251,7 @@ var (
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -46,6 +47,8 @@ const (
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
|
||||
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
|
||||
|
||||
func cacheEntryExpiration() time.Duration {
|
||||
@@ -74,7 +77,7 @@ type AccountManager interface {
|
||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(accountID string) ([]*User, error)
|
||||
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error
|
||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
|
||||
DeletePeer(accountID, peerID, userID string) error
|
||||
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
@@ -115,8 +118,8 @@ type AccountManager interface {
|
||||
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
HasConnectedChannel(peerID string) bool
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
@@ -128,6 +131,8 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
|
||||
GroupValidation(accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||
CancelPeerRoutines(peer *nbpeer.Peer) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -242,19 +247,19 @@ type UserPermissions struct {
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
NonDeletable bool `json:"non_deletable"`
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
NonDeletable bool `json:"non_deletable"`
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
Permissions UserPermissions `json:"permissions"`
|
||||
Permissions UserPermissions `json:"permissions"`
|
||||
}
|
||||
|
||||
// getRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
@@ -278,7 +283,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
||||
return routes
|
||||
}
|
||||
|
||||
// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
@@ -384,6 +389,8 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
|
||||
log.Debugf("GetNetworkMap with trace: %s", string(debug.Stack()))
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
@@ -956,7 +963,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -1007,7 +1014,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -1092,19 +1099,21 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
}
|
||||
delete(userData, idp.UnsetAccountID)
|
||||
|
||||
rcvdUsers := 0
|
||||
for accountID, users := range userData {
|
||||
rcvdUsers += len(users)
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.Infof("warmed up IDP cache with %d entries", len(userData))
|
||||
log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
@@ -1120,7 +1129,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
|
||||
}
|
||||
|
||||
if user.Id != account.CreatedBy {
|
||||
if user.Role != UserRoleOwner {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||
}
|
||||
for _, otherUser := range account.Users {
|
||||
@@ -1263,7 +1272,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
// ignore service users and users provisioned by integrations than are never logged in
|
||||
for _, user := range account.Users {
|
||||
if user.IsServiceUser {
|
||||
@@ -1272,7 +1281,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
|
||||
if user.Issued == UserIssuedIntegration {
|
||||
continue
|
||||
}
|
||||
users[user.Id] = struct{}{}
|
||||
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
|
||||
}
|
||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(users, account.Id)
|
||||
@@ -1345,22 +1354,57 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
|
||||
data, err := am.getAccountFromCache(accountID, false)
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
|
||||
var data []*idp.UserData
|
||||
var err error
|
||||
|
||||
maxAttempts := 2
|
||||
|
||||
data, err = am.getAccountFromCache(accountID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userDataMap := make(map[string]struct{})
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if am.isCacheFresh(accountUsers, data) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if attempt > 1 {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
log.Infof("refreshing cache for account %s", accountID)
|
||||
data, err = am.refreshCache(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if attempt == maxAttempts {
|
||||
log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
|
||||
func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
|
||||
userDataMap := make(map[string]*idp.UserData, len(data))
|
||||
for _, datum := range data {
|
||||
userDataMap[datum.ID] = struct{}{}
|
||||
userDataMap[datum.ID] = datum
|
||||
}
|
||||
|
||||
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
||||
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
||||
knownUsersCount := len(accountUsers)
|
||||
for user := range accountUsers {
|
||||
if _, ok := userDataMap[user]; ok {
|
||||
for user, loggedInOnce := range accountUsers {
|
||||
if datum, ok := userDataMap[user]; ok {
|
||||
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
||||
log.Infof("user %s has a pending invite and has logged in once, cache invalid", user)
|
||||
return false
|
||||
}
|
||||
knownUsersCount--
|
||||
continue
|
||||
}
|
||||
@@ -1369,15 +1413,11 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
|
||||
|
||||
// if we know users that are not yet in cache more likely cache is outdated
|
||||
if knownUsersCount > 0 {
|
||||
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount)
|
||||
// reload cache once avoiding loops
|
||||
data, err = am.refreshCache(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount)
|
||||
return false
|
||||
}
|
||||
|
||||
return data, err
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error {
|
||||
@@ -1425,29 +1465,14 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account,
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
//
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
existingAcc *Account,
|
||||
domainAcc *Account,
|
||||
primaryDomain bool,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) error {
|
||||
var err error
|
||||
|
||||
if domainAcc != nil && existingAcc.Id != domainAcc.Id {
|
||||
err = am.updateAccountDomainAttributes(existingAcc, claims, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = am.updateAccountDomainAttributes(existingAcc, claims, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
@@ -1547,7 +1572,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(user.Id)
|
||||
@@ -1630,7 +1655,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
unlock := am.Store.AcquireAccountLock(newAcc.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
@@ -1649,7 +1674,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser {
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -1781,12 +1806,33 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
|
||||
account, err := am.Store.GetAccountByUser(claims.UserId)
|
||||
if err == nil {
|
||||
err = am.handleExistingUserAccount(account, domainAccount, claims)
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
|
||||
|
||||
err = am.handleExistingUserAccount(account, primaryDomain, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return am.handleNewUserAccount(domainAccount, claims)
|
||||
} else {
|
||||
// other error
|
||||
@@ -1794,6 +1840,56 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
|
||||
if err != nil {
|
||||
return nil, nil, mapError(err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peer.Key, false, nil, account)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||
@@ -1849,6 +1945,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
||||
log.Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
updatedAccount, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get account %s: %v", accountID, err)
|
||||
|
||||
@@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
@@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1745,7 +1750,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@@ -11,133 +11,134 @@ type Code struct {
|
||||
Code string
|
||||
}
|
||||
|
||||
// Existing consts must not be changed, as this will break the compatibility with the existing data
|
||||
const (
|
||||
// PeerAddedByUser indicates that a user added a new peer to the system
|
||||
PeerAddedByUser Activity = iota
|
||||
PeerAddedByUser Activity = 0
|
||||
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
|
||||
PeerAddedWithSetupKey
|
||||
PeerAddedWithSetupKey Activity = 1
|
||||
// UserJoined indicates that a new user joined the account
|
||||
UserJoined
|
||||
UserJoined Activity = 2
|
||||
// UserInvited indicates that a new user was invited to join the account
|
||||
UserInvited
|
||||
UserInvited Activity = 3
|
||||
// AccountCreated indicates that a new account has been created
|
||||
AccountCreated
|
||||
AccountCreated Activity = 4
|
||||
// PeerRemovedByUser indicates that a user removed a peer from the system
|
||||
PeerRemovedByUser
|
||||
PeerRemovedByUser Activity = 5
|
||||
// RuleAdded indicates that a user added a new rule
|
||||
RuleAdded
|
||||
RuleAdded Activity = 6
|
||||
// RuleUpdated indicates that a user updated a rule
|
||||
RuleUpdated
|
||||
RuleUpdated Activity = 7
|
||||
// RuleRemoved indicates that a user removed a rule
|
||||
RuleRemoved
|
||||
RuleRemoved Activity = 8
|
||||
// PolicyAdded indicates that a user added a new policy
|
||||
PolicyAdded
|
||||
PolicyAdded Activity = 9
|
||||
// PolicyUpdated indicates that a user updated a policy
|
||||
PolicyUpdated
|
||||
PolicyUpdated Activity = 10
|
||||
// PolicyRemoved indicates that a user removed a policy
|
||||
PolicyRemoved
|
||||
PolicyRemoved Activity = 11
|
||||
// SetupKeyCreated indicates that a user created a new setup key
|
||||
SetupKeyCreated
|
||||
SetupKeyCreated Activity = 12
|
||||
// SetupKeyUpdated indicates that a user updated a setup key
|
||||
SetupKeyUpdated
|
||||
SetupKeyUpdated Activity = 13
|
||||
// SetupKeyRevoked indicates that a user revoked a setup key
|
||||
SetupKeyRevoked
|
||||
SetupKeyRevoked Activity = 14
|
||||
// SetupKeyOverused indicates that setup key usage exhausted
|
||||
SetupKeyOverused
|
||||
SetupKeyOverused Activity = 15
|
||||
// GroupCreated indicates that a user created a group
|
||||
GroupCreated
|
||||
GroupCreated Activity = 16
|
||||
// GroupUpdated indicates that a user updated a group
|
||||
GroupUpdated
|
||||
GroupUpdated Activity = 17
|
||||
// GroupAddedToPeer indicates that a user added group to a peer
|
||||
GroupAddedToPeer
|
||||
GroupAddedToPeer Activity = 18
|
||||
// GroupRemovedFromPeer indicates that a user removed peer group
|
||||
GroupRemovedFromPeer
|
||||
GroupRemovedFromPeer Activity = 19
|
||||
// GroupAddedToUser indicates that a user added group to a user
|
||||
GroupAddedToUser
|
||||
GroupAddedToUser Activity = 20
|
||||
// GroupRemovedFromUser indicates that a user removed a group from a user
|
||||
GroupRemovedFromUser
|
||||
GroupRemovedFromUser Activity = 21
|
||||
// UserRoleUpdated indicates that a user changed the role of a user
|
||||
UserRoleUpdated
|
||||
UserRoleUpdated Activity = 22
|
||||
// GroupAddedToSetupKey indicates that a user added group to a setup key
|
||||
GroupAddedToSetupKey
|
||||
GroupAddedToSetupKey Activity = 23
|
||||
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
|
||||
GroupRemovedFromSetupKey
|
||||
GroupRemovedFromSetupKey Activity = 24
|
||||
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
|
||||
GroupAddedToDisabledManagementGroups
|
||||
GroupAddedToDisabledManagementGroups Activity = 25
|
||||
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
|
||||
GroupRemovedFromDisabledManagementGroups
|
||||
GroupRemovedFromDisabledManagementGroups Activity = 26
|
||||
// RouteCreated indicates that a user created a route
|
||||
RouteCreated
|
||||
RouteCreated Activity = 27
|
||||
// RouteRemoved indicates that a user deleted a route
|
||||
RouteRemoved
|
||||
RouteRemoved Activity = 28
|
||||
// RouteUpdated indicates that a user updated a route
|
||||
RouteUpdated
|
||||
RouteUpdated Activity = 29
|
||||
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
|
||||
PeerSSHEnabled
|
||||
PeerSSHEnabled Activity = 30
|
||||
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
|
||||
PeerSSHDisabled
|
||||
PeerSSHDisabled Activity = 31
|
||||
// PeerRenamed indicates that a user renamed a peer
|
||||
PeerRenamed
|
||||
PeerRenamed Activity = 32
|
||||
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
|
||||
PeerLoginExpirationEnabled
|
||||
PeerLoginExpirationEnabled Activity = 33
|
||||
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
|
||||
PeerLoginExpirationDisabled
|
||||
PeerLoginExpirationDisabled Activity = 34
|
||||
// NameserverGroupCreated indicates that a user created a nameservers group
|
||||
NameserverGroupCreated
|
||||
NameserverGroupCreated Activity = 35
|
||||
// NameserverGroupDeleted indicates that a user deleted a nameservers group
|
||||
NameserverGroupDeleted
|
||||
NameserverGroupDeleted Activity = 36
|
||||
// NameserverGroupUpdated indicates that a user updated a nameservers group
|
||||
NameserverGroupUpdated
|
||||
NameserverGroupUpdated Activity = 37
|
||||
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
|
||||
AccountPeerLoginExpirationEnabled
|
||||
AccountPeerLoginExpirationEnabled Activity = 38
|
||||
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
|
||||
AccountPeerLoginExpirationDisabled
|
||||
AccountPeerLoginExpirationDisabled Activity = 39
|
||||
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
||||
AccountPeerLoginExpirationDurationUpdated
|
||||
AccountPeerLoginExpirationDurationUpdated Activity = 40
|
||||
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
||||
PersonalAccessTokenCreated
|
||||
PersonalAccessTokenCreated Activity = 41
|
||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||
PersonalAccessTokenDeleted
|
||||
PersonalAccessTokenDeleted Activity = 42
|
||||
// ServiceUserCreated indicates that a user created a service user
|
||||
ServiceUserCreated
|
||||
ServiceUserCreated Activity = 43
|
||||
// ServiceUserDeleted indicates that a user deleted a service user
|
||||
ServiceUserDeleted
|
||||
ServiceUserDeleted Activity = 44
|
||||
// UserBlocked indicates that a user blocked another user
|
||||
UserBlocked
|
||||
UserBlocked Activity = 45
|
||||
// UserUnblocked indicates that a user unblocked another user
|
||||
UserUnblocked
|
||||
UserUnblocked Activity = 46
|
||||
// UserDeleted indicates that a user deleted another user
|
||||
UserDeleted
|
||||
UserDeleted Activity = 47
|
||||
// GroupDeleted indicates that a user deleted group
|
||||
GroupDeleted
|
||||
GroupDeleted Activity = 48
|
||||
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
||||
UserLoggedInPeer
|
||||
UserLoggedInPeer Activity = 49
|
||||
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
|
||||
PeerLoginExpired
|
||||
PeerLoginExpired Activity = 50
|
||||
// DashboardLogin indicates that the user logged in to the dashboard
|
||||
DashboardLogin
|
||||
DashboardLogin Activity = 51
|
||||
// IntegrationCreated indicates that the user created an integration
|
||||
IntegrationCreated
|
||||
IntegrationCreated Activity = 52
|
||||
// IntegrationUpdated indicates that the user updated an integration
|
||||
IntegrationUpdated
|
||||
IntegrationUpdated Activity = 53
|
||||
// IntegrationDeleted indicates that the user deleted an integration
|
||||
IntegrationDeleted
|
||||
IntegrationDeleted Activity = 54
|
||||
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
|
||||
AccountPeerApprovalEnabled
|
||||
AccountPeerApprovalEnabled Activity = 55
|
||||
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
|
||||
AccountPeerApprovalDisabled
|
||||
AccountPeerApprovalDisabled Activity = 56
|
||||
// PeerApproved indicates that the peer has been approved
|
||||
PeerApproved
|
||||
PeerApproved Activity = 57
|
||||
// PeerApprovalRevoked indicates that the peer approval has been revoked
|
||||
PeerApprovalRevoked
|
||||
PeerApprovalRevoked Activity = 58
|
||||
// TransferredOwnerRole indicates that the user transferred the owner role of the account
|
||||
TransferredOwnerRole
|
||||
TransferredOwnerRole Activity = 59
|
||||
// PostureCheckCreated indicates that the user created a posture check
|
||||
PostureCheckCreated
|
||||
PostureCheckCreated Activity = 60
|
||||
// PostureCheckUpdated indicates that the user updated a posture check
|
||||
PostureCheckUpdated
|
||||
PostureCheckUpdated Activity = 61
|
||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||
PostureCheckDeleted
|
||||
PostureCheckDeleted Activity = 62
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
|
||||
@@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountLock acquires account lock and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Debugf("acquiring lock for account %s", accountID)
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
@@ -295,6 +295,12 @@ func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||
// This method is still returns a write lock as file store can't handle read locks
|
||||
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(accountID)
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -572,6 +578,18 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return account.Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// GetInstallationID returns the installation ID from the store
|
||||
func (s *FileStore) GetInstallationID() string {
|
||||
return s.InstallationID
|
||||
|
||||
@@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -214,7 +214,7 @@ func difference(a, b []string) []string {
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountId)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
return err
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||
@@ -149,11 +149,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(peer)
|
||||
|
||||
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||
}
|
||||
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||
}
|
||||
@@ -207,7 +202,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
|
||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
@@ -38,7 +39,7 @@ type emptyObject struct {
|
||||
}
|
||||
|
||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
AuthCfg: authCfg,
|
||||
}
|
||||
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -181,8 +181,8 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||
w.WriteHeader(200)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
||||
@@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
log.Debugf("UserCount: %v", len(data))
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
if r.NonDeletable {
|
||||
|
||||
@@ -20,8 +20,8 @@ type ErrorResponse struct {
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP response in JSON format
|
||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
WriteError(err, w)
|
||||
@@ -63,8 +63,8 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||
|
||||
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||
w.WriteHeader(httpStatus)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(httpStatus)
|
||||
err := json.NewEncoder(w).Encode(&ErrorResponse{
|
||||
Message: errMsg,
|
||||
Code: httpStatus,
|
||||
|
||||
@@ -115,7 +115,15 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
data.Set("scope", "https://graph.microsoft.com/.default")
|
||||
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get base url and add "/.default" as scope
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
scopeURL := baseURL + "/.default"
|
||||
data.Set("scope", scopeURL)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
|
||||
@@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOktaUserToUserData parse okta user to UserData.
|
||||
// parseOktaUser parse okta user to UserData.
|
||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||
var oktaUser struct {
|
||||
Email string `json:"email"`
|
||||
|
||||
@@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(userID)
|
||||
|
||||
@@ -13,6 +13,7 @@ type AuthorizationClaims struct {
|
||||
Domain string
|
||||
DomainCategory string
|
||||
LastLogin time.Time
|
||||
Invited bool
|
||||
|
||||
Raw jwt.MapClaims
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
UserIDClaim = "sub"
|
||||
// LastLoginSuffix claim for the last login
|
||||
LastLoginSuffix = "nb_last_login"
|
||||
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
||||
Invited = "nb_invited"
|
||||
)
|
||||
|
||||
// ExtractClaims Extract function type
|
||||
@@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
if ok {
|
||||
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
|
||||
}
|
||||
invitedBool, ok := claims[c.authAudience+Invited]
|
||||
if ok {
|
||||
jwtClaims.Invited = invitedBool.(bool)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
|
||||
if claims.LastLogin != (time.Time{}) {
|
||||
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
|
||||
}
|
||||
|
||||
if claims.Invited {
|
||||
claimMaps[audience+Invited] = true
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
require.NoError(t, err, "creating testing request failed")
|
||||
@@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
AccountId: "testAcc",
|
||||
LastLogin: lastLogin,
|
||||
DomainCategory: "public",
|
||||
Invited: true,
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_domain_category": "public",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"https://login/nb_last_login": lastLogin.Format(layout),
|
||||
"sub": "test",
|
||||
"https://login/" + Invited: true,
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
|
||||
204
management/server/migration/migration.go
Normal file
204
management/server/migration/migration.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
|
||||
// T is the type of the model that contains the field to be migrated.
|
||||
// S is the type of the field to be migrated.
|
||||
func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error {
|
||||
|
||||
oldColumnName := fieldName
|
||||
newColumnName := fieldName + "_tmp"
|
||||
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.Debugf("Table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
var item string
|
||||
if err := db.Model(model).Select(oldColumnName).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Debugf("No records in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("fetch first record: %w", err)
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
var syntaxError *json.SyntaxError
|
||||
err = json.Unmarshal([]byte(item), &js)
|
||||
if err == nil || !errors.As(err, &syntaxError) {
|
||||
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
||||
}
|
||||
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var field S
|
||||
|
||||
str, ok := row[oldColumnName].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("type assertion failed")
|
||||
}
|
||||
reader := strings.NewReader(str)
|
||||
|
||||
if err := gob.NewDecoder(reader).Decode(&field); err != nil {
|
||||
return fmt.Errorf("gob decode error: %w", err)
|
||||
}
|
||||
|
||||
jsonValue, err := json.Marshal(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("re-encode to JSON: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
|
||||
return fmt.Errorf("update row: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
|
||||
}
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
|
||||
// T is the type of the model that contains the field to be migrated.
|
||||
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
|
||||
oldColumnName := fieldName
|
||||
newColumnName := fieldName + "_tmp"
|
||||
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.Printf("Table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
var item sql.NullString
|
||||
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Printf("No records in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("fetch first record: %w", err)
|
||||
}
|
||||
|
||||
if item.Valid {
|
||||
var js json.RawMessage
|
||||
var syntaxError *json.SyntaxError
|
||||
err = json.Unmarshal([]byte(item.String), &js)
|
||||
if err == nil || !errors.As(err, &syntaxError) {
|
||||
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
||||
}
|
||||
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Printf("No records in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var blobValue string
|
||||
if columnValue := row[oldColumnName]; columnValue != nil {
|
||||
value, ok := columnValue.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("type assertion failed")
|
||||
}
|
||||
blobValue = value
|
||||
}
|
||||
|
||||
columnIpValue := net.IP(blobValue)
|
||||
if net.ParseIP(columnIpValue.String()) == nil {
|
||||
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
|
||||
columnIpValue = net.IPv6loopback
|
||||
}
|
||||
|
||||
jsonValue, err := json.Marshal(columnIpValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("re-encode to JSON: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
|
||||
return fmt.Errorf("update row: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if indexName != "" {
|
||||
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
|
||||
return fmt.Errorf("drop index %s: %w", indexName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
|
||||
}
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
|
||||
|
||||
return nil
|
||||
}
|
||||
161
management/server/migration/migration_test.go
Normal file
161
management/server/migration/migration_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func setupDatabase(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
|
||||
PrepareStmt: true,
|
||||
})
|
||||
|
||||
require.NoError(t, err, "Failed to open database")
|
||||
return db
|
||||
}
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||
}
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &route.Route{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||
require.NoError(t, err, "Failed to parse CIDR")
|
||||
|
||||
type network struct {
|
||||
server.Network
|
||||
Net net.IPNet `gorm:"serializer:gob"`
|
||||
}
|
||||
|
||||
type account struct {
|
||||
server.Account
|
||||
Network *network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
|
||||
require.NoError(t, err, "Failed to insert Gob data")
|
||||
|
||||
var gobStr string
|
||||
err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error
|
||||
assert.NoError(t, err, "Failed to fetch Gob data")
|
||||
|
||||
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
|
||||
require.NoError(t, err, "Failed to decode Gob data")
|
||||
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail with Gob data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
||||
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated")
|
||||
}
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &route.Route{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||
require.NoError(t, err, "Failed to parse CIDR")
|
||||
|
||||
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail with JSON data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
||||
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
type location struct {
|
||||
nbpeer.Location
|
||||
ConnectionIP net.IP
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
nbpeer.Peer
|
||||
Location location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
}
|
||||
|
||||
type account struct {
|
||||
server.Account
|
||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||
}
|
||||
|
||||
err = db.Save(&account{
|
||||
Account: server.Account{Id: "123"},
|
||||
Peers: []peer{
|
||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
).Error
|
||||
require.NoError(t, err, "Failed to insert blob data")
|
||||
|
||||
var blobValue string
|
||||
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
|
||||
assert.NoError(t, err, "Failed to fetch blob data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
require.NoError(t, err, "Migration should not fail with net.IP blob data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be migrated")
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.Account{
|
||||
Id: "1234",
|
||||
PeersG: []nbpeer.Peer{
|
||||
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
|
||||
}
|
||||
@@ -22,80 +22,93 @@ type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(pat string) error
|
||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(pat string) error
|
||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
|
||||
}
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
@@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
|
||||
}
|
||||
@@ -626,7 +639,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(sync)
|
||||
}
|
||||
@@ -706,7 +719,7 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
|
||||
if am.UpdateIntegratedValidatorGroupsFunc != nil {
|
||||
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)
|
||||
|
||||
@@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
@@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -36,7 +36,7 @@ type NetworkMap struct {
|
||||
|
||||
type Network struct {
|
||||
Identifier string `json:"id"`
|
||||
Net net.IPNet `gorm:"serializer:gob"`
|
||||
Net net.IPNet `gorm:"serializer:json"`
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
|
||||
@@ -88,21 +88,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -156,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -278,7 +264,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -362,7 +348,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
@@ -374,14 +360,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.lookupUserInCache(userID, account)
|
||||
if err == nil {
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow)
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
@@ -518,25 +504,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
@@ -551,8 +519,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
|
||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if requiresApproval {
|
||||
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if peerNotValid {
|
||||
emptyMap := &NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
}
|
||||
@@ -563,11 +531,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
|
||||
am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
validPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
@@ -603,7 +571,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
@@ -760,7 +728,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
@@ -795,7 +763,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -865,6 +833,10 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
||||
return
|
||||
}
|
||||
for _, peer := range peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
|
||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||
|
||||
@@ -13,13 +13,13 @@ type Peer struct {
|
||||
// ID is an internal ID of the peer
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
@@ -61,7 +61,7 @@ type PeerStatus struct { //nolint:revive
|
||||
|
||||
// Location is a geo location information of a Peer based on public connection IP
|
||||
type Location struct {
|
||||
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup
|
||||
ConnectionIP net.IP `gorm:"serializer:json"` // from grpc peer or reverse proxy headers depends on setup
|
||||
CountryCode string
|
||||
CityName string
|
||||
GeoNameID uint // city level geoname id
|
||||
|
||||
@@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -115,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -194,7 +194,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if routeToSave == nil {
|
||||
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user