Compare commits

..

61 Commits

Author SHA1 Message Date
bcmmbaga
8c44187900 Merge branch 'refs/heads/test/add-stopwatch' into deploy/peer-performance 2024-05-06 14:29:22 +03:00
Pascal Fischer
0dc050b354 update sqlite operations 2024-05-06 13:25:46 +02:00
Pascal Fischer
69eb6f17c6 remove stacktrace 2024-05-03 17:30:26 +02:00
bcmmbaga
ccbe2e0100 Merge branch 'refs/heads/migrate-with-json-serializer' into deploy/peer-performance 2024-05-03 14:11:23 +03:00
bcmmbaga
5b042eee74 migrate null blob values 2024-05-03 14:07:54 +03:00
Pascal Fischer
125c418d85 add debug to GetAccount sqlite store 2024-04-30 13:52:19 +02:00
Pascal Fischer
9f607029c0 add debug to GetAccount sqlite store 2024-04-30 13:35:04 +02:00
bcmmbaga
4df1d556c8 fix tests 2024-04-30 13:35:49 +03:00
bcmmbaga
acaf315152 Merge branch 'refs/heads/migrate-with-json-serializer' into deploy/peer-performance 2024-04-30 13:32:48 +03:00
bcmmbaga
81b6f7dc56 Merge remote-tracking branch 'refs/remotes/origin/test/add-stopwatch' into deploy/peer-performance 2024-04-30 13:32:37 +03:00
bcmmbaga
cb9a8e9fa4 Add tests 2024-04-30 13:27:14 +03:00
Pascal Fischer
f68d9ce042 fix linter 2024-04-30 12:03:59 +02:00
Pascal Fischer
1f182411c6 update store 2024-04-30 11:43:48 +02:00
Pascal Fischer
68e971cb82 Merge remote-tracking branch 'origin/test/add-stopwatch' into test/add-stopwatch 2024-04-30 11:29:08 +02:00
Pascal Fischer
f4a464fb69 update account manager mock 2024-04-30 11:28:57 +02:00
bcmmbaga
7aff0e828b Refactor 2024-04-30 10:07:40 +03:00
bcmmbaga
90ebf0017b remove duplicate index 2024-04-29 23:21:00 +03:00
bcmmbaga
bf8c6b95de run net ip migration 2024-04-29 23:20:43 +03:00
bcmmbaga
8415064758 migrate net ip field from blob to json 2024-04-29 23:20:22 +03:00
Pascal Fischer
809f6cdbc7 add stacktrace for get network map 2024-04-29 18:56:58 +02:00
Viktor Liu
e435e39158 Fix route selection IDs (#1890) 2024-04-29 18:43:14 +02:00
Maycon Santos
fd26e989e3 Check if channel exist before sending network map (#1894)
Check for connection channel before calculating and sending the network map
2024-04-29 18:31:52 +02:00
Pascal Fischer
7474876d6a add stacktrace for get network map 2024-04-29 18:12:22 +02:00
Pascal Fischer
2095e35217 remove stopwatches + logs 2024-04-29 16:26:44 +02:00
Pascal Fischer
02537e5536 remove stopwatches 2024-04-29 14:57:57 +02:00
bcmmbaga
f99477d3db serialize net.IP as json 2024-04-29 15:45:46 +03:00
Pascal Fischer
3085383729 introduce RWLocks 2024-04-29 14:03:33 +02:00
Pascal Fischer
07bdf977d0 update db access 2024-04-26 18:05:41 +02:00
Pascal Fischer
4a147006d2 add table 2024-04-26 17:45:10 +02:00
Viktor Liu
4424162bce Add client debug features (#1884)
* Add status anonymization
* Add OS/arch to the status command
* Use human-friendly last-update status messages
* Add debug bundle command to collect (anonymized) logs
* Add debug log level command
* And debug for a certain time span command
2024-04-26 17:20:10 +02:00
Pascal Fischer
480192028e try improving Sync method 2024-04-26 17:18:24 +02:00
Viktor Liu
54b045d9ca Replaces powershell with the route command and cache route lookups on windows (#1880) 2024-04-26 16:37:27 +02:00
Bethuel Mmbaga
71c6437bab add content type before writing header (#1887) 2024-04-25 21:20:24 +02:00
Pascal Fischer
027c0e5c63 update 2024-04-25 15:47:32 +02:00
Pascal Fischer
fbb1181b64 update 2024-04-25 15:44:06 +02:00
Pascal Fischer
e50ed290d9 update 2024-04-25 15:17:57 +02:00
Pascal Fischer
8f8cfcbc20 add method timing logs 2024-04-25 14:15:36 +02:00
pascal-fischer
7b254cb966 add methods to manage rosenpass settings for iOS (#1879) 2024-04-23 19:26:03 +02:00
pascal-fischer
8f3a0f2c38 Add retry to IdP cache lookup (#1882) 2024-04-23 19:23:43 +02:00
pascal-fischer
1f33e2e003 Support exit nodes on iOS (#1878) 2024-04-23 19:12:16 +02:00
pascal-fischer
1e6addaa65 Add account locks to getAccountWithAuthorizationClaims method (#1847) 2024-04-23 19:09:58 +02:00
Viktor Liu
f51dc13f8c Add route selection functionality for CLI and GUI (#1865) 2024-04-23 14:42:53 +02:00
dependabot[bot]
3477108ce7 Bump golang.org/x/net from 0.20.0 to 0.23.0 (#1867)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.20.0 to 0.23.0.
- [Commits](https://github.com/golang/net/compare/v0.20.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-23 12:48:25 +02:00
Maycon Santos
012e624296 Fix DNS not found query response (#1877)
for local queries, we should return NXDOMAIN instead of NOERROR

Also, updated gomobile for Android and iOS builds
2024-04-23 10:20:09 +02:00
Maycon Santos
4c5e987e02 Add support for GUI app to display error (#1844) 2024-04-22 11:57:38 +02:00
Maycon Santos
a80c8b0176 Redeem invite only when incoming user was invited (#1861)
checks for users with pending invite status in the cache that already logged in and refresh the cache
2024-04-22 11:10:27 +02:00
Misha Bragin
9e01155d2e Add new intro image 2024-04-22 11:00:52 +02:00
Maycon Santos
3c3111ad01 Copy client binary to a directory in path (#1842) 2024-04-22 10:14:07 +02:00
Misha Bragin
b74078fd95 Use a better way to insert data in batches (#1874) 2024-04-20 22:04:20 +02:00
Viktor Liu
77488ad11a Migrate serializer:gob fields to serializer:json (#1855) 2024-04-18 18:14:21 +02:00
Viktor Liu
e3b76448f3 Fix ICE endpoint remote port in status command (#1851) 2024-04-16 14:01:59 +02:00
Viktor Liu
e0de86d6c9 Use fixed activity codes (#1846)
* Add duplicate constants check
2024-04-15 14:15:46 +02:00
Zoltan Papp
5204d07811 Pass integrated validator for API (#1814)
Pass integrated validator for API handler
2024-04-15 12:08:38 +02:00
Viktor Liu
5ea24ba56e Add sysctl opts to prevent reverse path filtering from dropping fwmark packets (#1839) 2024-04-12 17:53:07 +02:00
Viktor Liu
d30cf8706a Allow disabling custom routing (#1840) 2024-04-12 16:53:11 +02:00
Viktor Liu
15a2feb723 Use fixed preference for rules (#1836) 2024-04-12 16:07:03 +02:00
Viktor Liu
91b2f9fc51 Use route active store (#1834) 2024-04-12 15:22:40 +02:00
Carlos Hernandez
76702c8a09 Add safe read/write to route map (#1760) 2024-04-11 22:12:23 +02:00
Viktor Liu
061f673a4f Don't use the custom dialer as non-root (#1823) 2024-04-11 15:29:03 +02:00
Zoltan Papp
9505805313 Rename variable (#1829) 2024-04-11 14:08:03 +02:00
Maycon Santos
704c67dec8 Allow owners that did not create the account to delete it (#1825)
Sometimes the Owner role will be passed to new users, and they need to be able to delete the account
2024-04-11 10:02:51 +02:00
89 changed files with 4840 additions and 781 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta
ignore_words_list: erro,clienta,hastable,
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -33,6 +33,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:

View File

@@ -38,7 +38,7 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build android netbird lib
@@ -56,10 +56,10 @@ jobs:
with:
go-version: "1.21.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -44,7 +44,8 @@
### Open-Source Network Security in a Single Platform
![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f)
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### Key features

View File

@@ -1,5 +1,5 @@
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -0,0 +1,212 @@
package anonymize
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"regexp"
"slices"
"strings"
)
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
currentAnonIPv4 netip.Addr
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
return &Anonymizer{
ipAnonymizer: map[netip.Addr]netip.Addr{},
domainAnonymizer: map[string]string{},
currentAnonIPv4: startIPv4,
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
}
}
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
if ip.IsLoopback() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
a.isInAnonymizedRange(ip) {
return ip
}
if _, ok := a.ipAnonymizer[ip]; !ok {
if ip.Is4() {
a.ipAnonymizer[ip] = a.currentAnonIPv4
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
} else {
a.ipAnonymizer[ip] = a.currentAnonIPv6
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
}
}
return a.ipAnonymizer[ip]
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
return true
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
return true
}
return false
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
}
return a.AnonymizeIP(addr).String()
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return uri
}
var anonymizedHost string
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
u.Host = anonymizedHost
}
return u.String()
}
func (a *Anonymizer) AnonymizeString(str string) string {
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
for domain, anonDomain := range a.domainAnonymizer {
str = strings.ReplaceAll(str, domain, anonDomain)
}
str = a.AnonymizeSchemeURI(str)
str = a.AnonymizeDNSLogLine(str)
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
}
if slices.Contains(wellKnown, addr.String()) {
return true
}
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
return cgnatRange.Contains(addr)
}
func generateRandomString(length int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
continue
}
result[i] = letters[num.Int64()]
}
return string(result)
}

View File

@@ -0,0 +1,223 @@
package anonymize_test
import (
"net/netip"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
)
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
name string
ip string
expect string
}{
{"Well known", "8.8.8.8", "8.8.8.8"},
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ip := netip.MustParseAddr(tc.ip)
anonymizedIP := anonymizer.AnonymizeIP(ip)
if anonymizedIP.String() != tc.expect {
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
}
})
}
}
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
domain string
expectPattern string
shouldAnonymize bool
}{
{
"General Domain",
"example.com",
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDomain(tc.domain)
if tc.shouldAnonymize {
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
} else {
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
}
})
}
}
func TestAnonymizeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
uri string
regex string
}{
{
"HTTP URI with Port",
"http://example.com:80/path",
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
},
{
"HTTP URI without Port",
"http://example.com/path",
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
},
{
"Opaque URI with Port",
"stun:example.com:80?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
},
{
"Opaque URI without Port",
"stun:example.com?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeURI(tc.uri)
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizeSchemeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeSchemeURI(tc.input)
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizString_MemorizedDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "This is a test string including the domain example.com which should be anonymized."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
}
func TestAnonymizeString_DoubleURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "Check out our site at https://example.com for more info."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
}
func TestAnonymizeString_IPAddresses(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
tests := []struct {
name string
input string
expect string
}{
{
name: "IPv4 Address",
input: "Error occurred at IP 122.138.1.1",
expect: "Error occurred at IP 198.51.100.0",
},
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeString(tc.input)
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
})
}
}

248
client/cmd/debug.go Normal file
View File

@@ -0,0 +1,248 @@
package cmd
import (
"context"
"fmt"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
}
var debugBundleCmd = &cobra.Command{
Use: "bundle",
Example: " netbird debug bundle",
Short: "Create a debug bundle",
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
RunE: debugBundle,
}
var logCmd = &cobra.Command{
Use: "log",
Short: "Manage logging for the Netbird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
}
var logLevelCmd = &cobra.Command{
Use: "level <level>",
Short: "Set the logging level for this session",
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
Available log levels are:
panic: for panic level, highest level of severity
fatal: for fatal level errors that cause the program to exit
error: for error conditions
warn: for warning conditions
info: for informational messages
debug: for debug-level messages
trace: for trace-level messages, which include more fine-grained information than debug`,
Args: cobra.ExactArgs(1),
RunE: setLogLevel,
}
var forCmd = &cobra.Command{
Use: "for <time>",
Short: "Run debug logs for a specified duration and create a debug bundle",
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
Example: " netbird debug for 5m",
Args: cobra.ExactArgs(1),
RunE: runForDuration,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := parseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: level,
})
if err != nil {
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level set successfully to", args[0])
return nil
}
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second)
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
startTime := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
elapsed := time.Since(startTime)
if elapsed >= duration {
return
}
remaining := duration - elapsed
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}

View File

@@ -65,6 +65,7 @@ var (
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
@@ -119,6 +120,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
@@ -126,8 +129,20 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -335,3 +350,14 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true
}
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}

131
client/cmd/route.go Normal file
View File

@@ -0,0 +1,131 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -24,7 +24,7 @@ var (
)
var sshCmd = &cobra.Command{
Use: "ssh",
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You can verify the connection by running:\n\n" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
@@ -14,6 +16,7 @@ import (
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -144,9 +147,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed initializing log %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, cmd)
resp, err := getStatus(ctx)
if err != nil {
return err
}
@@ -191,7 +194,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -200,7 +203,7 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -283,6 +286,11 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
@@ -525,8 +533,16 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"Daemon version: %s\n"+
"OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
@@ -538,6 +554,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
@@ -593,15 +610,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
lastStatusUpdate := "-"
if !peerState.LastStatusUpdate.IsZero() {
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
}
lastWireGuardHandshake := "-"
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
@@ -652,8 +660,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
lastStatusUpdate,
lastWireGuardHandshake,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
@@ -722,3 +730,129 @@ func countEnabled(dnsServers []nsServerGroupStateOutput) int {
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -3,6 +3,8 @@ package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
@@ -487,9 +489,15 @@ dnsServers:
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail :=
expectedDetail := fmt.Sprintf(
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
@@ -500,8 +508,8 @@ func TestParsingToDetail(t *testing.T) {
Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Last connection update: 2001-01-01 01:01:01
Last WireGuard handshake: 2001-01-01 01:01:02
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
@@ -516,15 +524,16 @@ func TestParsingToDetail(t *testing.T) {
Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Last connection update: 2002-02-02 02:02:02
Last WireGuard handshake: 2002-02-02 02:02:03
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: development
CLI version: %s
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
@@ -539,7 +548,7 @@ Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
@@ -547,8 +556,8 @@ Peers count: 2/2 Connected
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString :=
`Daemon version: 0.14.1
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
@@ -572,3 +581,31 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -31,7 +31,7 @@ import (
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
}
// RunClientWithProbes runs the client's main logic with probes attached
@@ -43,8 +43,9 @@ func RunClientWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
}
// RunClientMobile with main logic on mobile system
@@ -66,7 +67,7 @@ func RunClientMobile(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
}
func RunClientiOS(
@@ -82,7 +83,7 @@ func RunClientiOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
}
func runClient(
@@ -94,6 +95,7 @@ func runClient(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
defer func() {
if r := recover(); r != nil {
@@ -243,6 +245,9 @@ func runClient(
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
if engineChan != nil {
engineChan <- engine
}
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
@@ -252,6 +257,10 @@ func runClient(
backOff.Reset()
if engineChan != nil {
engineChan <- nil
}
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)

View File

@@ -31,6 +31,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
response := d.lookupRecord(r)
if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)

View File

@@ -111,6 +111,9 @@ type Engine struct {
// TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route
cancel context.CancelFunc
ctx context.Context
@@ -216,6 +219,8 @@ func (e *Engine) Stop() error {
return err
}
e.clientRoutes = nil
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond)
@@ -695,11 +700,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update routes, err: %v", err)
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
@@ -794,6 +802,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)
@@ -1228,6 +1237,28 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
}
}
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
return e.clientRoutes
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 {
id = id[:i]
}
routes[id] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -577,10 +578,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return testCase.inputErr
return nil, nil, testCase.inputErr
},
}
@@ -597,8 +598,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
})
}
}
@@ -742,8 +743,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
return nil, nil, nil
},
}

View File

@@ -0,0 +1 @@
package internal

View File

@@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil {
return err
}
@@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -465,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
@@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {

View File

@@ -14,6 +14,7 @@ import (
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
@@ -30,7 +31,38 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
Routes map[string]struct{}
routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.Routes != nil {
peerState.Routes = receivedState.Routes
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true
}
return false
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"errors"
"testing"
"sync"
"github.com/stretchr/testify/assert"
)
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -3,6 +3,7 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"time"
@@ -155,7 +156,11 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if currScore != 0 && currScore < chosenScore+0.1 {
return currID
} else {
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
var peer string
if route := c.routes[chosen]; route != nil {
peer = route.Peer
}
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, peer, chosenScore, c.network)
}
}
@@ -196,7 +201,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err)
}
delete(state.Routes, c.network.String())
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
@@ -215,7 +220,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
@@ -256,7 +261,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// otherwise add the route to the system
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
}
@@ -268,10 +273,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
@@ -347,3 +349,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
}
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
}
return intf
}

View File

@@ -14,8 +14,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -27,7 +29,9 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelection(map[string][]*route.Route)
GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
@@ -40,6 +44,7 @@ type DefaultManager struct {
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
@@ -53,6 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
@@ -68,6 +74,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -99,37 +109,42 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
}
m.ctx = nil
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
return nil, nil, m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return fmt.Errorf("update routes: %w", err)
return nil, nil, fmt.Errorf("update routes: %w", err)
}
}
return nil
return newServerRoutesMap, newClientRoutesIDMap, nil
}
}
@@ -143,16 +158,51 @@ func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.initialRouteRanges()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
// GetRouteSelector returns the route selector
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
return m.clientNetworks
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks)
m.stopObsoleteClients(networks)
for id, routes := range networks {
if _, found := m.clientNetworks[id]; found {
// don't touch existing client network watchers
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks)
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
@@ -169,7 +219,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
}
}
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
@@ -201,7 +251,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifiesRoutes(initialRoutes)
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)
@@ -210,9 +260,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin", "ios":
return true
}
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported

View File

@@ -428,11 +428,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
}
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@@ -7,14 +7,17 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
StopFunc func()
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelectionFunc func(map[string][]*route.Route)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func()
}
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -27,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return fmt.Errorf("method UpdateRoutes is not implemented")
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
}
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)
}
}
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
if m.GetRouteSelectorFunc != nil {
return m.GetRouteSelectorFunc()
}
return nil
}
// Start mock implementation of Start from Manager interface

View File

@@ -5,6 +5,7 @@ package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
@@ -17,7 +18,7 @@ import (
type ref struct {
count int
nexthop netip.Addr
intf string
intf *net.Interface
}
type RouteManager struct {
@@ -30,8 +31,8 @@ type RouteManager struct {
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap

View File

@@ -60,17 +60,13 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
return nil
}
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
@@ -84,7 +80,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
@@ -153,12 +149,7 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
@@ -168,39 +159,34 @@ func addRouteToNonVPNIntf(
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, "", ErrRouteNotAllowed
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
var exitIntf string
if intf != nil {
exitIntf = intf.Name
}
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
if initialIntf != nil {
exitIntf = initialIntf.Name
}
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
@@ -208,7 +194,7 @@ func addRouteToNonVPNIntf(
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
@@ -250,7 +236,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
@@ -277,7 +263,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf string) error {
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
@@ -343,7 +329,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, string, error) {
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {

View File

@@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(netip.Prefix, string) error {
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, string) error {
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@@ -27,15 +27,15 @@ func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("add", prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("delete", prefix, nexthop, intf)
}
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
@@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = "lo0"
intf = &net.Interface{Name: "lo0"}
}
}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != "" {
args = append(args, "-interface", intf)
} else if intf != nil {
args = append(args, "-interface", intf.Name)
}
if err := retryRouteCmd(args); err != nil {

View File

@@ -33,7 +33,7 @@ func init() {
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := "lo0"
intf := &net.Interface{Name: "lo0"}
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {

View File

@@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(netip.Prefix, string) error {
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, string) error {
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@@ -4,14 +4,14 @@ package routemanager
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -32,30 +32,53 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
tableID int
family int
priority int
invert bool
suppressPrefix int
description string
}
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams {
return []ruleParams{
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}
}
@@ -69,10 +92,8 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
if isLegacy() {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
@@ -81,6 +102,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
@@ -94,7 +122,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
isLegacy = true
setIsLegacy(true)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
@@ -108,7 +136,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
if isLegacy {
if isLegacy() {
return cleanupRoutingWithRouteManager(routeManager)
}
@@ -123,27 +151,37 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
@@ -158,8 +196,8 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericRemoveVPNRoute(prefix, intf)
}
@@ -217,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
@@ -277,7 +315,10 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}
@@ -286,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -336,22 +377,8 @@ func flushRoutes(tableID, family int) error {
}
func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
_, err := setSysctl(ipv4ForwardingPath, 1, false)
return err
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
@@ -429,7 +456,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -446,58 +473,30 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
func removeAllRules(params ruleParams) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if intf != nil {
route.LinkIndex = intf.Index
}
if addr.IsValid() {
route.Gw = addr.AsSlice()
if intf == "" {
intf = addr.Zone()
}
}
if intf != "" {
link, err := netlink.LinkByName(intf)
if err != nil {
return fmt.Errorf("set interface %s: %w", intf, err)
// if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
}
route.LinkIndex = link.Attrs().Index
}
return nil
@@ -509,3 +508,83 @@ func getAddressFamily(prefix netip.Prefix) int {
}
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -3,6 +3,7 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
@@ -14,10 +15,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf string) error {
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -50,6 +50,8 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@@ -61,13 +63,17 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil)
_, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
@@ -78,7 +84,7 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
@@ -182,12 +188,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@@ -200,14 +210,18 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
err := addVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@@ -217,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
@@ -345,43 +359,47 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}

View File

@@ -6,8 +6,12 @@ import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
@@ -21,6 +25,10 @@ type Win32_IP4RouteTable struct {
Mask string
}
var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -32,15 +40,23 @@ func cleanupRouting() error {
}
func getRoutesFromTable() ([]netip.Prefix, error) {
var routes []Win32_IP4RouteTable
mux.Lock()
defer mux.Unlock()
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
var routes []Win32_IP4RouteTable
err := wmi.Query(query, &routes)
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
var prefixList []netip.Prefix
prefixList = nil
for _, route := range routes {
addr, err := netip.ParseAddr(route.Destination)
if err != nil {
@@ -60,54 +76,29 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
prefixList = append(prefixList, routePrefix)
}
}
lastUpdate = time.Now()
return prefixList, nil
}
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
destinationPrefix := prefix.String()
psCmd := "New-NetRoute"
addressFamily := "IPv4"
if prefix.Addr().Is6() {
addressFamily = "IPv6"
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
psCmd, addressFamily, destinationPrefix,
)
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
args := []string{"add", prefix.String()}
if nexthop.IsValid() {
script = fmt.Sprintf(
`%s -NextHop "%s"`, script, nexthop,
)
args = append(args, nexthop.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
log.Tracef("PowerShell %s: %s", script, string(out))
if err != nil {
return fmt.Errorf("PowerShell add route: %w", err)
if intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index))
}
return nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
@@ -116,21 +107,20 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
var intfIdx string
if nexthop.Zone() != "" {
intfIdx = nexthop.Zone()
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
// Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
@@ -145,3 +135,7 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) err
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -0,0 +1,132 @@
package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
route "github.com/netbirdio/netbird/route"
)
type RouteSelector struct {
selectedRoutes map[string]struct{}
selectAll bool
}
func NewRouteSelector() *RouteSelector {
return &RouteSelector{
selectedRoutes: map[string]struct{}{},
// default selects all routes
selectAll: true,
}
}
// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
if !appendRoute {
rs.selectedRoutes = map[string]struct{}{}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
rs.selectedRoutes[route] = struct{}{}
}
rs.selectAll = false
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{}
}
// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool {
if rs.selectAll {
return true
}
_, selected := rs.selectedRoutes[routeID]
return selected
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
if rs.selectAll {
return maps.Clone(routes)
}
filtered := map[string][]*route.Route{}
for id, rt := range routes {
netID := id
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt
}
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -0,0 +1,275 @@
package routeselector_test
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
selectRoutes []string
append bool
wantSelected []string
wantError bool
}{
{
name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes, initial all deselected",
initialSelected: []string{},
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2", "route3"},
wantSelected: []string{"route2", "route3"},
},
{
name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route1"},
wantError: true,
},
{
name: "Append route with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route1", "route2"},
},
{
name: "Append route without initial selection",
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial some selected",
initialSelected: []string{"route1"},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.SelectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
deselectRoutes []string
wantSelected []string
wantError bool
}{
{
name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route3"},
},
{
name: "Deselect specific routes, initial all deselected",
initialSelected: []string{},
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"},
wantSelected: []string{"route2"},
},
{
name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route2"},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{},
},
{
name: "Initial some selected",
initialSelected: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.DeselectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
assert.True(t, rs.IsSelected("route1"))
assert.True(t, rs.IsSelected("route2"))
assert.False(t, rs.IsSelected("route3"))
assert.False(t, rs.IsSelected("route4"))
}
func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
routes := map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
}, filtered)
}

View File

@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
}
// Set the fwmark on the socket.
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}

View File

@@ -71,6 +71,42 @@ func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled store if rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)

View File

@@ -1,17 +1,17 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v4.24.3
// protoc v3.12.4
// source: daemon.proto
package proto
import (
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
duration "github.com/golang/protobuf/ptypes/duration"
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
)
@@ -23,6 +23,70 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type LogLevel int32
const (
LogLevel_UNKNOWN LogLevel = 0
LogLevel_PANIC LogLevel = 1
LogLevel_FATAL LogLevel = 2
LogLevel_ERROR LogLevel = 3
LogLevel_WARN LogLevel = 4
LogLevel_INFO LogLevel = 5
LogLevel_DEBUG LogLevel = 6
LogLevel_TRACE LogLevel = 7
)
// Enum value maps for LogLevel.
var (
LogLevel_name = map[int32]string{
0: "UNKNOWN",
1: "PANIC",
2: "FATAL",
3: "ERROR",
4: "WARN",
5: "INFO",
6: "DEBUG",
7: "TRACE",
}
LogLevel_value = map[string]int32{
"UNKNOWN": 0,
"PANIC": 1,
"FATAL": 2,
"ERROR": 3,
"WARN": 4,
"INFO": 5,
"DEBUG": 6,
"TRACE": 7,
}
)
func (x LogLevel) Enum() *LogLevel {
p := new(LogLevel)
*p = x
return p
}
func (x LogLevel) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (LogLevel) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[0].Descriptor()
}
func (LogLevel) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[0]
}
func (x LogLevel) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use LogLevel.Descriptor instead.
func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
type LoginRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -766,23 +830,23 @@ type PeerState struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *duration.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
}
func (x *PeerState) Reset() {
@@ -838,7 +902,7 @@ func (x *PeerState) GetConnStatus() string {
return ""
}
func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp {
func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp {
if x != nil {
return x.ConnStatusUpdate
}
@@ -894,7 +958,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string {
return ""
}
func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp {
func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp {
if x != nil {
return x.LastWireguardHandshake
}
@@ -929,7 +993,7 @@ func (x *PeerState) GetRoutes() []string {
return nil
}
func (x *PeerState) GetLatency() *durationpb.Duration {
func (x *PeerState) GetLatency() *duration.Duration {
if x != nil {
return x.Latency
}
@@ -1383,6 +1447,442 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState {
return nil
}
type ListRoutesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ListRoutesRequest) Reset() {
*x = ListRoutesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[19]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListRoutesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListRoutesRequest) ProtoMessage() {}
func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[19]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead.
func (*ListRoutesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{19}
}
type ListRoutesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"`
}
func (x *ListRoutesResponse) Reset() {
*x = ListRoutesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListRoutesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListRoutesResponse) ProtoMessage() {}
func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[20]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead.
func (*ListRoutesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{20}
}
func (x *ListRoutesResponse) GetRoutes() []*Route {
if x != nil {
return x.Routes
}
return nil
}
type SelectRoutesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"`
Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"`
All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *SelectRoutesRequest) Reset() {
*x = SelectRoutesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SelectRoutesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SelectRoutesRequest) ProtoMessage() {}
func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead.
func (*SelectRoutesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{21}
}
func (x *SelectRoutesRequest) GetRouteIDs() []string {
if x != nil {
return x.RouteIDs
}
return nil
}
func (x *SelectRoutesRequest) GetAppend() bool {
if x != nil {
return x.Append
}
return false
}
func (x *SelectRoutesRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
type SelectRoutesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SelectRoutesResponse) Reset() {
*x = SelectRoutesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[22]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SelectRoutesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SelectRoutesResponse) ProtoMessage() {}
func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[22]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead.
func (*SelectRoutesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{22}
}
type Route struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"`
Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"`
Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"`
}
func (x *Route) Reset() {
*x = Route{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Route) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Route) ProtoMessage() {}
func (x *Route) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[23]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Route.ProtoReflect.Descriptor instead.
func (*Route) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{23}
}
func (x *Route) GetID() string {
if x != nil {
return x.ID
}
return ""
}
func (x *Route) GetNetwork() string {
if x != nil {
return x.Network
}
return ""
}
func (x *Route) GetSelected() bool {
if x != nil {
return x.Selected
}
return false
}
type DebugBundleRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
}
func (x *DebugBundleRequest) Reset() {
*x = DebugBundleRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DebugBundleRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DebugBundleRequest) ProtoMessage() {}
func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead.
func (*DebugBundleRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{24}
}
func (x *DebugBundleRequest) GetAnonymize() bool {
if x != nil {
return x.Anonymize
}
return false
}
func (x *DebugBundleRequest) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
type DebugBundleResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
}
func (x *DebugBundleResponse) Reset() {
*x = DebugBundleResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[25]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DebugBundleResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DebugBundleResponse) ProtoMessage() {}
func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[25]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead.
func (*DebugBundleResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{25}
}
func (x *DebugBundleResponse) GetPath() string {
if x != nil {
return x.Path
}
return ""
}
type SetLogLevelRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Level LogLevel `protobuf:"varint,1,opt,name=level,proto3,enum=daemon.LogLevel" json:"level,omitempty"`
}
func (x *SetLogLevelRequest) Reset() {
*x = SetLogLevelRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[26]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetLogLevelRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetLogLevelRequest) ProtoMessage() {}
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[26]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{26}
}
func (x *SetLogLevelRequest) GetLevel() LogLevel {
if x != nil {
return x.Level
}
return LogLevel_UNKNOWN
}
type SetLogLevelResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SetLogLevelResponse) Reset() {
*x = SetLogLevelResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[27]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetLogLevelResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetLogLevelResponse) ProtoMessage() {}
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[27]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{27}
}
var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{
@@ -1601,32 +2101,92 @@ var file_daemon_proto_rawDesc = []byte{
0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65,
0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74,
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02,
0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53,
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61,
0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33,
0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a,
0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74,
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22,
0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01,
0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c,
0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14,
0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4d, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x65, 0x64, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a,
0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49,
0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09,
0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52,
0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a,
0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43,
0x45, 0x10, 0x07, 0x32, 0xee, 0x05, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70,
0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a,
0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c,
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -1641,58 +2201,81 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData
}
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19)
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 28)
var file_daemon_proto_goTypes = []interface{}{
(*LoginRequest)(nil), // 0: daemon.LoginRequest
(*LoginResponse)(nil), // 1: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 4: daemon.UpRequest
(*UpResponse)(nil), // 5: daemon.UpResponse
(*StatusRequest)(nil), // 6: daemon.StatusRequest
(*StatusResponse)(nil), // 7: daemon.StatusResponse
(*DownRequest)(nil), // 8: daemon.DownRequest
(*DownResponse)(nil), // 9: daemon.DownResponse
(*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse
(*PeerState)(nil), // 12: daemon.PeerState
(*LocalPeerState)(nil), // 13: daemon.LocalPeerState
(*SignalState)(nil), // 14: daemon.SignalState
(*ManagementState)(nil), // 15: daemon.ManagementState
(*RelayState)(nil), // 16: daemon.RelayState
(*NSGroupState)(nil), // 17: daemon.NSGroupState
(*FullStatus)(nil), // 18: daemon.FullStatus
(*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp
(*durationpb.Duration)(nil), // 20: google.protobuf.Duration
(LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest
(*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 5: daemon.UpRequest
(*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusRequest)(nil), // 7: daemon.StatusRequest
(*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownRequest)(nil), // 9: daemon.DownRequest
(*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*PeerState)(nil), // 13: daemon.PeerState
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*SignalState)(nil), // 15: daemon.SignalState
(*ManagementState)(nil), // 16: daemon.ManagementState
(*RelayState)(nil), // 17: daemon.RelayState
(*NSGroupState)(nil), // 18: daemon.NSGroupState
(*FullStatus)(nil), // 19: daemon.FullStatus
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*Route)(nil), // 24: daemon.Route
(*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse
(*SetLogLevelRequest)(nil), // 27: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 28: daemon.SetLogLevelResponse
(*timestamp.Timestamp)(nil), // 29: google.protobuf.Timestamp
(*duration.Duration)(nil), // 30: google.protobuf.Duration
}
var file_daemon_proto_depIdxs = []int32{
18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
20, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
15, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
14, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
13, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
12, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
16, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
17, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
0, // 10: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
2, // 11: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
4, // 12: daemon.DaemonService.Up:input_type -> daemon.UpRequest
6, // 13: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
8, // 14: daemon.DaemonService.Down:input_type -> daemon.DownRequest
10, // 15: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
1, // 16: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
3, // 17: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
5, // 18: daemon.DaemonService.Up:output_type -> daemon.UpResponse
7, // 19: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
9, // 20: daemon.DaemonService.Down:output_type -> daemon.DownResponse
11, // 21: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
16, // [16:22] is the sub-list for method output_type
10, // [10:16] is the sub-list for method input_type
10, // [10:10] is the sub-list for extension type_name
10, // [10:10] is the sub-list for extension extendee
0, // [0:10] is the sub-list for field type_name
19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
29, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
29, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
30, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
13, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
0, // 11: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
1, // 12: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
3, // 13: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
5, // 14: daemon.DaemonService.Up:input_type -> daemon.UpRequest
7, // 15: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
9, // 16: daemon.DaemonService.Down:input_type -> daemon.DownRequest
11, // 17: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
20, // 18: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 19: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
22, // 20: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
25, // 21: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
27, // 22: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
2, // 23: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
4, // 24: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
6, // 25: daemon.DaemonService.Up:output_type -> daemon.UpResponse
8, // 26: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 27: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 28: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 29: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 30: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 31: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
26, // 32: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
28, // 33: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
23, // [23:34] is the sub-list for method output_type
12, // [12:23] is the sub-list for method input_type
12, // [12:12] is the sub-list for extension type_name
12, // [12:12] is the sub-list for extension extendee
0, // [0:12] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -1929,6 +2512,114 @@ func file_daemon_proto_init() {
return nil
}
}
file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListRoutesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListRoutesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelectRoutesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelectRoutesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Route); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DebugBundleRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DebugBundleResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
type x struct{}
@@ -1936,13 +2627,14 @@ func file_daemon_proto_init() {
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 0,
NumMessages: 19,
NumEnums: 1,
NumMessages: 28,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_daemon_proto_goTypes,
DependencyIndexes: file_daemon_proto_depIdxs,
EnumInfos: file_daemon_proto_enumTypes,
MessageInfos: file_daemon_proto_msgTypes,
}.Build()
File_daemon_proto = out.File

View File

@@ -27,6 +27,21 @@ service DaemonService {
// GetConfig of the daemon.
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {}
// List available network routes
rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {}
// Select specific routes
rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// Deselect specific routes
rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// DebugBundle creates a debug bundle
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
// SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
};
message LoginRequest {
@@ -195,4 +210,53 @@ message FullStatus {
repeated PeerState peers = 4;
repeated RelayState relays = 5;
repeated NSGroupState dns_servers = 6;
}
message ListRoutesRequest {
}
message ListRoutesResponse {
repeated Route routes = 1;
}
message SelectRoutesRequest {
repeated string routeIDs = 1;
bool append = 2;
bool all = 3;
}
message SelectRoutesResponse {
}
message Route {
string ID = 1;
string network = 2;
bool selected = 3;
}
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
}
message DebugBundleResponse {
string path = 1;
}
enum LogLevel {
UNKNOWN = 0;
PANIC = 1;
FATAL = 2;
ERROR = 3;
WARN = 4;
INFO = 5;
DEBUG = 6;
TRACE = 7;
}
message SetLogLevelRequest {
LogLevel level = 1;
}
message SetLogLevelResponse {
}

View File

@@ -31,6 +31,16 @@ type DaemonServiceClient interface {
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
// List available network routes
ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
}
type daemonServiceClient struct {
@@ -95,6 +105,51 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
return out, nil
}
func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) {
out := new(ListRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -112,6 +167,16 @@ type DaemonServiceServer interface {
Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
// List available network routes
ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -137,6 +202,21 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do
func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented")
}
func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
}
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -258,6 +338,96 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SelectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SelectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeselectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DebugBundleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DebugBundle(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DebugBundle",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetLogLevelRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetLogLevel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetLogLevel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -289,6 +459,26 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetConfig",
Handler: _DaemonService_GetConfig_Handler,
},
{
MethodName: "ListRoutes",
Handler: _DaemonService_ListRoutes_Handler,
},
{
MethodName: "SelectRoutes",
Handler: _DaemonService_SelectRoutes_Handler,
},
{
MethodName: "DeselectRoutes",
Handler: _DaemonService_DeselectRoutes_Handler,
},
{
MethodName: "DebugBundle",
Handler: _DaemonService_DebugBundle_Handler,
},
{
MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto",

175
client/server/debug.go Normal file
View File

@@ -0,0 +1,175 @@
package server
import (
"archive/zip"
"bufio"
"context"
"fmt"
"io"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
if err := bundlePath.Close(); err != nil {
log.Errorf("failed to close zip file: %v", err)
}
if err != nil {
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
log.Errorf("Failed to remove zip file: %v", err2)
}
}
}()
archive := zip.NewWriter(bundlePath)
defer func() {
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
scanner := bufio.NewScanner(reader)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
defer func() {
if err := writer.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err)
}
}()
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
errChan <- fmt.Errorf("write line to writer: %w", err)
return
}
}
if err := scanner.Err(); err != nil {
errChan <- fmt.Errorf("read line from scanner: %w", err)
return
}
}
// SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
level, err := log.ParseLevel(req.Level.String())
if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err)
}
log.SetLevel(level)
log.Infof("Log level set to %s", level.String())
return &proto.SetLogLevelResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
Method: zip.Deflate,
}
writer, err := archive.CreateHeader(header)
if err != nil {
return fmt.Errorf("create zip file header: %w", err)
}
if _, err := io.Copy(writer, reader); err != nil {
return fmt.Errorf("write file to zip: %w", err)
}
return nil
}
func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
}
for _, nsGroup := range status.NSGroupStates {
for _, domain := range nsGroup.Domains {
a.AnonymizeDomain(domain)
}
}
for _, relay := range status.Relays {
if relay.URI != nil {
a.AnonymizeURI(relay.URI.String())
}
}
}

110
client/server/route.go Normal file
View File

@@ -0,0 +1,110 @@
package server
import (
"context"
"fmt"
"net/netip"
"sort"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
)
type selectRoute struct {
NetID string
Network netip.Prefix
Selected bool
}
// ListRoutes returns a list of all available routes.
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.engine == nil {
return nil, fmt.Errorf("not connected")
}
routesMap := s.engine.GetClientRoutesWithNetID()
routeSelector := s.engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
route := &selectRoute{
NetID: id,
Network: rt[0].Network,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
}
sort.Slice(routes, func(i, j int) bool {
iPrefix := routes[i].Network.Bits()
jPrefix := routes[j].Network.Bits()
if iPrefix == jPrefix {
iAddr := routes[i].Network.Addr()
jAddr := routes[j].Network.Addr()
if iAddr == jAddr {
return routes[i].NetID < routes[j].NetID
}
return iAddr.String() < jAddr.String()
}
return iPrefix < jPrefix
})
var pbRoutes []*proto.Route
for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{
ID: route.NetID,
Network: route.Network.String(),
Selected: route.Selected,
})
}
return &proto.ListRoutesResponse{
Routes: pbRoutes,
}, nil
}
// SelectRoutes selects specific routes based on the client request.
func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.SelectAllRoutes()
} else {
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}
// DeselectRoutes deselects specific routes based on the client request.
func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.DeselectAllRoutes()
} else {
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}

View File

@@ -15,15 +15,15 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -57,6 +57,8 @@ type Server struct {
config *internal.Config
proto.UnimplementedDaemonServiceServer
engine *internal.Engine
statusRecorder *peer.Status
sessionWatcher *internal.SessionWatcher
@@ -141,8 +143,11 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
}
return nil
@@ -153,6 +158,7 @@ func (s *Server) Start() error {
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
engineChan chan<- *internal.Engine,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -182,7 +188,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@@ -562,7 +568,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
return &proto.UpResponse{}, nil
}
@@ -579,6 +588,8 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.engine = nil
return &proto.DownResponse{}, nil
}
@@ -661,7 +672,6 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
PreSharedKey: preSharedKey,
}, nil
}
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
@@ -673,6 +683,22 @@ func (s *Server) onSessionExpire() {
}
}
// watchEngine watches the engine channel and updates the engine state
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
log.Tracef("Started watching engine")
for {
select {
case <-ctx.Done():
s.engine = nil
log.Tracef("Stopped watching engine")
return
case engine := <-engineChan:
log.Tracef("Received engine from watcher")
s.engine = engine
}
}
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
@@ -718,7 +744,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes),
Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)

View File

@@ -2,11 +2,12 @@ package server
import (
"context"
"github.com/netbirdio/management-integrations/integrations"
"net"
"testing"
"time"
"github.com/netbirdio/management-integrations/integrations"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -69,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}

View File

@@ -1,7 +1,5 @@
//go:build !(linux && 386)
// +build !linux !386
// skipping linux 32 bits build and tests
package main
import (
@@ -58,14 +56,23 @@ func main() {
var showSettings bool
flag.BoolVar(&showSettings, "settings", false, "run settings windows")
var showRoutes bool
flag.BoolVar(&showRoutes, "routes", false, "run routes windows")
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
flag.Parse()
a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
client := newServiceClient(daemonAddr, a, showSettings)
if showSettings {
if errorMSG != "" {
showErrorMSG(errorMSG)
return
}
client := newServiceClient(daemonAddr, a, showSettings, showRoutes)
if showSettings || showRoutes {
a.Run()
} else {
if err := checkPIDFile(); err != nil {
@@ -128,6 +135,7 @@ type serviceClient struct {
mVersionDaemon *systray.MenuItem
mUpdate *systray.MenuItem
mQuit *systray.MenuItem
mRoutes *systray.MenuItem
// application with main windows.
app fyne.App
@@ -152,12 +160,15 @@ type serviceClient struct {
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
showRoutes bool
wRoutes fyne.Window
}
// newServiceClient instance constructor
//
// This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient {
s := &serviceClient{
ctx: context.Background(),
addr: addr,
@@ -165,6 +176,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
sendNotification: false,
showSettings: showSettings,
showRoutes: showRoutes,
update: version.NewUpdate(),
}
@@ -184,14 +196,16 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
}
if showSettings {
s.showUIElements()
s.showSettingsUI()
return s
} else if showRoutes {
s.showRoutesUI()
}
return s
}
func (s *serviceClient) showUIElements() {
func (s *serviceClient) showSettingsUI() {
// add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings")
s.iMngURL = widget.NewEntry()
@@ -209,6 +223,18 @@ func (s *serviceClient) showUIElements() {
s.wSettings.Show()
}
// showErrorMSG opens a fyne app window to display the supplied message
func showErrorMSG(msg string) {
app := app.New()
w := app.NewWindow("NetBird Error")
content := widget.NewLabel(msg)
content.Wrapping = fyne.TextWrapWord
w.SetContent(content)
w.Resize(fyne.NewSize(400, 100))
w.Show()
app.Run()
}
// getSettingsForm to embed it into settings window.
func (s *serviceClient) getSettingsForm() *widget.Form {
return &widget.Form{
@@ -397,6 +423,7 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Connected")
s.mUp.Disable()
s.mDown.Enable()
s.mRoutes.Enable()
systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
s.connected = false
@@ -409,6 +436,7 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
s.mRoutes.Disable()
systrayIconState = false
}
@@ -464,9 +492,11 @@ func (s *serviceClient) onTrayReady() {
s.mUp = systray.AddMenuItem("Connect", "Connect")
s.mDown = systray.AddMenuItem("Disconnect", "Disconnect")
s.mDown.Disable()
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Wiretrustee Admin Panel")
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel")
systray.AddSeparator()
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window")
s.mRoutes.Disable()
systray.AddSeparator()
s.mAbout = systray.AddMenuItem("About", "About")
@@ -504,16 +534,22 @@ func (s *serviceClient) onTrayReady() {
case <-s.mAdminPanel.ClickedCh:
err = open.Run(s.adminURL)
case <-s.mUp.ClickedCh:
s.mUp.Disabled()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
s.runSelfCommand("error-msg", err.Error())
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
s.runSelfCommand("error-msg", err.Error())
return
}
}()
@@ -521,24 +557,8 @@ func (s *serviceClient) onTrayReady() {
s.mSettings.Disable()
go func() {
defer s.mSettings.Enable()
proc, err := os.Executable()
if err != nil {
log.Errorf("show settings: %v", err)
return
}
cmd := exec.Command(proc, "--settings=true")
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start settings UI: %v, %s", err, string(out))
return
}
if len(out) != 0 {
log.Info("settings change:", string(out))
}
// update config in systray when settings windows closed
s.getSrvConfig()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
@@ -548,6 +568,12 @@ func (s *serviceClient) onTrayReady() {
if err != nil {
log.Errorf("%s", err)
}
case <-s.mRoutes.ClickedCh:
s.mRoutes.Disable()
go func() {
defer s.mRoutes.Enable()
s.runSelfCommand("routes", "true")
}()
}
if err != nil {
log.Errorf("process connection: %v", err)
@@ -556,6 +582,24 @@ func (s *serviceClient) onTrayReady() {
}()
}
func (s *serviceClient) runSelfCommand(command, arg string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("show %s failed with error: %v", command, err)
return
}
cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg))
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start %s UI: %v, %s", command, err, string(out))
return
}
if len(out) != 0 {
log.Infof("command %s executed: %s", command, string(out))
}
}
func normalizedVersion(version string) string {
versionString := version
if unicode.IsDigit(rune(versionString[0])) {

203
client/ui/route.go Normal file
View File

@@ -0,0 +1,203 @@
//go:build !(linux && 386)
package main
import (
"fmt"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(2))
go s.updateRoutes(grid)
routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid)
scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox(
layout.NewSpacer(),
widget.NewButton("Refresh", func() {
s.updateRoutes(grid)
}),
widget.NewButton("Select all", func() {
s.selectAllRoutes()
s.updateRoutes(grid)
}),
widget.NewButton("Deselect All", func() {
s.deselectAllRoutes()
s.updateRoutes(grid)
}),
layout.NewSpacer(),
)
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(5*time.Second, grid)
}
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
grid.Objects = nil
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
grid.Add(idHeader)
grid.Add(networkHeader)
for _, route := range routes {
r := route
checkBox := widget.NewCheck(r.ID, func(checked bool) {
s.selectRoute(r.ID, checked)
})
checkBox.Checked = route.Selected
checkBox.Resize(fyne.NewSize(20, 20))
checkBox.Refresh()
grid.Add(checkBox)
grid.Add(widget.NewLabel(r.Network))
}
s.wRoutes.Content().Refresh()
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
}
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list routes: %v", err)
}
return resp.Routes, nil
}
func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
req := &proto.SelectRoutesRequest{
RouteIDs: []string{id},
Append: checked,
}
if checked {
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select route: %v", err)
s.showError(fmt.Errorf("failed to select route: %v", err))
return
}
log.Infof("Route %s selected", id)
} else {
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect route: %v", err)
s.showError(fmt.Errorf("failed to deselect route: %v", err))
return
}
log.Infof("Route %s deselected", id)
}
}
func (s *serviceClient) selectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err))
return
}
log.Debug("All routes selected")
}
func (s *serviceClient) deselectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
return
}
log.Debug("All routes deselected")
}
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
s.updateRoutes(grid)
}
}()
s.wRoutes.SetOnClosed(func() {
ticker.Stop()
})
}
// wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string {
var sb strings.Builder
var currentLineLength int
for _, runeValue := range text {
sb.WriteRune(runeValue)
currentLineLength++
if currentLineLength >= lineLength || runeValue == '\n' {
sb.WriteRune('\n')
currentLineLength = 0
}
}
return sb.String()
}

10
go.mod
View File

@@ -21,8 +21,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/crypto v0.18.0
golang.org/x/sys v0.16.0
golang.org/x/crypto v0.21.0
golang.org/x/sys v0.18.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -82,10 +82,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.20.0
golang.org/x/net v0.23.0
golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.3.0
golang.org/x/term v0.16.0
golang.org/x/term v0.18.0
google.golang.org/api v0.126.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.5.3

16
go.sum
View File

@@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
@@ -581,8 +581,9 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -669,8 +670,9 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -761,16 +763,18 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -10,8 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
type wgKernelConfigurer struct {
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil {
return err
}
fwmark := nbnet.NetbirdFwmark
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,

View File

@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -251,7 +251,7 @@ var (
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

View File

@@ -11,6 +11,7 @@ import (
"net/netip"
"reflect"
"regexp"
"runtime/debug"
"strings"
"sync"
"time"
@@ -46,6 +47,8 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour
)
type userLoggedInOnce bool
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
func cacheEntryExpiration() time.Duration {
@@ -74,7 +77,7 @@ type AccountManager interface {
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error)
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error
MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
DeletePeer(accountID, peerID, userID string) error
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error)
@@ -115,8 +118,8 @@ type AccountManager interface {
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
@@ -128,6 +131,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
GroupValidation(accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
}
type DefaultAccountManager struct {
@@ -242,19 +247,19 @@ type UserPermissions struct {
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
Permissions UserPermissions `json:"permissions"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -384,6 +389,8 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
log.Debugf("GetNetworkMap with trace: %s", string(debug.Stack()))
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@@ -956,7 +963,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -1007,7 +1014,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -1092,19 +1099,21 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
}
delete(userData, idp.UnsetAccountID)
rcvdUsers := 0
for accountID, users := range userData {
rcvdUsers += len(users)
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil {
return err
}
}
log.Infof("warmed up IDP cache with %d entries", len(userData))
log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
return nil
}
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -1120,7 +1129,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
}
if user.Id != account.CreatedBy {
if user.Role != UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
for _, otherUser := range account.Users {
@@ -1263,7 +1272,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users))
users := make(map[string]userLoggedInOnce, len(account.Users))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
if user.IsServiceUser {
@@ -1272,7 +1281,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
if user.Issued == UserIssuedIntegration {
continue
}
users[user.Id] = struct{}{}
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)
@@ -1345,22 +1354,57 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
}
}
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
var data []*idp.UserData
var err error
maxAttempts := 2
data, err = am.getAccountFromCache(accountID, false)
if err != nil {
return nil, err
}
userDataMap := make(map[string]struct{})
for attempt := 1; attempt <= maxAttempts; attempt++ {
if am.isCacheFresh(accountUsers, data) {
return data, nil
}
if attempt > 1 {
time.Sleep(200 * time.Millisecond)
}
log.Infof("refreshing cache for account %s", accountID)
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
if attempt == maxAttempts {
log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts)
}
}
return data, nil
}
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
userDataMap := make(map[string]*idp.UserData, len(data))
for _, datum := range data {
userDataMap[datum.ID] = struct{}{}
userDataMap[datum.ID] = datum
}
// the accountUsers ID list of non integration users from store, we check if cache has all of them
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
knownUsersCount := len(accountUsers)
for user := range accountUsers {
if _, ok := userDataMap[user]; ok {
for user, loggedInOnce := range accountUsers {
if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
log.Infof("user %s has a pending invite and has logged in once, cache invalid", user)
return false
}
knownUsersCount--
continue
}
@@ -1369,15 +1413,11 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
// if we know users that are not yet in cache more likely cache is outdated
if knownUsersCount > 0 {
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount)
// reload cache once avoiding loops
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount)
return false
}
return data, err
return true
}
func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error {
@@ -1425,29 +1465,14 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account,
}
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(
existingAcc *Account,
domainAcc *Account,
primaryDomain bool,
claims jwtclaims.AuthorizationClaims,
) error {
var err error
if domainAcc != nil && existingAcc.Id != domainAcc.Id {
err = am.updateAccountDomainAttributes(existingAcc, claims, false)
if err != nil {
return err
}
} else {
err = am.updateAccountDomainAttributes(existingAcc, claims, true)
if err != nil {
return err
}
err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain)
if err != nil {
return err
}
// we should register the account ID to this user's metadata in our IDP manager
@@ -1547,7 +1572,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(user.Id)
@@ -1630,7 +1655,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireAccountLock(newAcc.Id)
unlock := am.Store.AcquireAccountWriteLock(newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
@@ -1649,7 +1674,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
if !user.IsServiceUser {
if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, nil, err
@@ -1781,12 +1806,33 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil {
err = am.handleExistingUserAccount(account, domainAccount, claims)
unlockAccount := am.Store.AcquireAccountWriteLock(account.Id)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(claims.UserId)
if err != nil {
return nil, err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
err = am.handleExistingUserAccount(account, primaryDomain, claims)
if err != nil {
return nil, err
}
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
if err != nil {
return nil, err
}
}
return am.handleNewUserAccount(domainAccount, claims)
} else {
// other error
@@ -1794,6 +1840,56 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
if err != nil {
return nil, nil, mapError(err)
}
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, nil
}
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
err = am.MarkPeerConnected(peer.Key, false, nil, account)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
}
return nil
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil
@@ -1849,6 +1945,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
}
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
log.Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err)

View File

@@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
@@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1745,7 +1750,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -11,133 +11,134 @@ type Code struct {
Code string
}
// Existing consts must not be changed, as this will break the compatibility with the existing data
const (
// PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota
PeerAddedByUser Activity = 0
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
PeerAddedWithSetupKey
PeerAddedWithSetupKey Activity = 1
// UserJoined indicates that a new user joined the account
UserJoined
UserJoined Activity = 2
// UserInvited indicates that a new user was invited to join the account
UserInvited
UserInvited Activity = 3
// AccountCreated indicates that a new account has been created
AccountCreated
AccountCreated Activity = 4
// PeerRemovedByUser indicates that a user removed a peer from the system
PeerRemovedByUser
PeerRemovedByUser Activity = 5
// RuleAdded indicates that a user added a new rule
RuleAdded
RuleAdded Activity = 6
// RuleUpdated indicates that a user updated a rule
RuleUpdated
RuleUpdated Activity = 7
// RuleRemoved indicates that a user removed a rule
RuleRemoved
RuleRemoved Activity = 8
// PolicyAdded indicates that a user added a new policy
PolicyAdded
PolicyAdded Activity = 9
// PolicyUpdated indicates that a user updated a policy
PolicyUpdated
PolicyUpdated Activity = 10
// PolicyRemoved indicates that a user removed a policy
PolicyRemoved
PolicyRemoved Activity = 11
// SetupKeyCreated indicates that a user created a new setup key
SetupKeyCreated
SetupKeyCreated Activity = 12
// SetupKeyUpdated indicates that a user updated a setup key
SetupKeyUpdated
SetupKeyUpdated Activity = 13
// SetupKeyRevoked indicates that a user revoked a setup key
SetupKeyRevoked
SetupKeyRevoked Activity = 14
// SetupKeyOverused indicates that setup key usage exhausted
SetupKeyOverused
SetupKeyOverused Activity = 15
// GroupCreated indicates that a user created a group
GroupCreated
GroupCreated Activity = 16
// GroupUpdated indicates that a user updated a group
GroupUpdated
GroupUpdated Activity = 17
// GroupAddedToPeer indicates that a user added group to a peer
GroupAddedToPeer
GroupAddedToPeer Activity = 18
// GroupRemovedFromPeer indicates that a user removed peer group
GroupRemovedFromPeer
GroupRemovedFromPeer Activity = 19
// GroupAddedToUser indicates that a user added group to a user
GroupAddedToUser
GroupAddedToUser Activity = 20
// GroupRemovedFromUser indicates that a user removed a group from a user
GroupRemovedFromUser
GroupRemovedFromUser Activity = 21
// UserRoleUpdated indicates that a user changed the role of a user
UserRoleUpdated
UserRoleUpdated Activity = 22
// GroupAddedToSetupKey indicates that a user added group to a setup key
GroupAddedToSetupKey
GroupAddedToSetupKey Activity = 23
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
GroupRemovedFromSetupKey
GroupRemovedFromSetupKey Activity = 24
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
GroupAddedToDisabledManagementGroups
GroupAddedToDisabledManagementGroups Activity = 25
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
GroupRemovedFromDisabledManagementGroups
GroupRemovedFromDisabledManagementGroups Activity = 26
// RouteCreated indicates that a user created a route
RouteCreated
RouteCreated Activity = 27
// RouteRemoved indicates that a user deleted a route
RouteRemoved
RouteRemoved Activity = 28
// RouteUpdated indicates that a user updated a route
RouteUpdated
RouteUpdated Activity = 29
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
PeerSSHEnabled
PeerSSHEnabled Activity = 30
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
PeerSSHDisabled
PeerSSHDisabled Activity = 31
// PeerRenamed indicates that a user renamed a peer
PeerRenamed
PeerRenamed Activity = 32
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
PeerLoginExpirationEnabled
PeerLoginExpirationEnabled Activity = 33
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
PeerLoginExpirationDisabled
PeerLoginExpirationDisabled Activity = 34
// NameserverGroupCreated indicates that a user created a nameservers group
NameserverGroupCreated
NameserverGroupCreated Activity = 35
// NameserverGroupDeleted indicates that a user deleted a nameservers group
NameserverGroupDeleted
NameserverGroupDeleted Activity = 36
// NameserverGroupUpdated indicates that a user updated a nameservers group
NameserverGroupUpdated
NameserverGroupUpdated Activity = 37
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
AccountPeerLoginExpirationEnabled
AccountPeerLoginExpirationEnabled Activity = 38
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
AccountPeerLoginExpirationDisabled
AccountPeerLoginExpirationDisabled Activity = 39
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated
AccountPeerLoginExpirationDurationUpdated Activity = 40
// PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated
PersonalAccessTokenCreated Activity = 41
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted
PersonalAccessTokenDeleted Activity = 42
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated
ServiceUserCreated Activity = 43
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
ServiceUserDeleted Activity = 44
// UserBlocked indicates that a user blocked another user
UserBlocked
UserBlocked Activity = 45
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
UserUnblocked Activity = 46
// UserDeleted indicates that a user deleted another user
UserDeleted
UserDeleted Activity = 47
// GroupDeleted indicates that a user deleted group
GroupDeleted
GroupDeleted Activity = 48
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer
UserLoggedInPeer Activity = 49
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired
PeerLoginExpired Activity = 50
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin
DashboardLogin Activity = 51
// IntegrationCreated indicates that the user created an integration
IntegrationCreated
IntegrationCreated Activity = 52
// IntegrationUpdated indicates that the user updated an integration
IntegrationUpdated
IntegrationUpdated Activity = 53
// IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted
IntegrationDeleted Activity = 54
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled
AccountPeerApprovalEnabled Activity = 55
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled
AccountPeerApprovalDisabled Activity = 56
// PeerApproved indicates that the peer has been approved
PeerApproved
PeerApproved Activity = 57
// PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked
PeerApprovalRevoked Activity = 58
// TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole
TransferredOwnerRole Activity = 59
// PostureCheckCreated indicates that the user created a posture check
PostureCheckCreated
PostureCheckCreated Activity = 60
// PostureCheckUpdated indicates that the user updated a posture check
PostureCheckUpdated
PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted
PostureCheckDeleted Activity = 62
)
var activityMap = map[Activity]Code{

View File

@@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -12,7 +12,7 @@ import (
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
return unlock
}
// AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
@@ -295,6 +295,12 @@ func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock
}
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID)
}
func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -572,6 +578,18 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return account.Copy(), nil
}
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}
return accountID, nil
}
// GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string {
return s.InstallationID

View File

@@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -214,7 +214,7 @@ func difference(a, b []string) []string {
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountLock(accountId)
unlock := am.Store.AcquireAccountWriteLock(accountId)
defer unlock()
account, err := am.Store.GetAccount(accountId)
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
if err != nil {
return mapError(err)
return err
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
@@ -149,11 +149,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(peer)
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
}
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
}
@@ -207,7 +202,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
_ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer)
}

View File

@@ -12,6 +12,7 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -38,7 +39,7 @@ type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

View File

@@ -181,8 +181,8 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil {
util.WriteError(err, w)

View File

@@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
serviceUser := r.URL.Query().Get("service_user")
log.Debugf("UserCount: %v", len(data))
users := make([]*api.User, 0)
for _, r := range data {
if r.NonDeletable {

View File

@@ -20,8 +20,8 @@ type ErrorResponse struct {
// WriteJSONObject simply writes object to the HTTP response in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(err, w)
@@ -63,8 +63,8 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
// WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(httpStatus)
err := json.NewEncoder(w).Encode(&ErrorResponse{
Message: errMsg,
Code: httpStatus,

View File

@@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
return errors.New("invalid groups")
}
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(userID)

View File

@@ -13,6 +13,7 @@ type AuthorizationClaims struct {
Domain string
DomainCategory string
LastLogin time.Time
Invited bool
Raw jwt.MapClaims
}

View File

@@ -20,6 +20,8 @@ const (
UserIDClaim = "sub"
// LastLoginSuffix claim for the last login
LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
)
// ExtractClaims Extract function type
@@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
if ok {
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
}
invitedBool, ok := claims[c.authAudience+Invited]
if ok {
jwtClaims.Invited = invitedBool.(bool)
}
return jwtClaims
}

View File

@@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
if claims.LastLogin != (time.Time{}) {
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
}
if claims.Invited {
claimMaps[audience+Invited] = true
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed")
@@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
AccountId: "testAcc",
LastLogin: lastLogin,
DomainCategory: "public",
Invited: true,
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc",
"https://login/nb_last_login": lastLogin.Format(layout),
"sub": "test",
"https://login/" + Invited: true,
},
},
testingFunc: require.EqualValues,

View File

@@ -0,0 +1,204 @@
package migration
import (
"database/sql"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"net"
"strings"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
// S is the type of the field to be migrated.
func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
var model T
if !db.Migrator().HasTable(&model) {
log.Debugf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
var item string
if err := db.Model(model).Select(oldColumnName).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debugf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
}
var js json.RawMessage
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
var rows []map[string]any
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
var field S
str, ok := row[oldColumnName].(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
reader := strings.NewReader(str)
if err := gob.NewDecoder(reader).Decode(&field); err != nil {
return fmt.Errorf("gob decode error: %w", err)
}
jsonValue, err := json.Marshal(field)
if err != nil {
return fmt.Errorf("re-encode to JSON: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
return fmt.Errorf("update row: %w", err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
}
return nil
}); err != nil {
return err
}
log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
return nil
}
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
var model T
if !db.Migrator().HasTable(&model) {
log.Printf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
var item sql.NullString
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Printf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
}
if item.Valid {
var js json.RawMessage
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item.String), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
}
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
var rows []map[string]any
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Printf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
var blobValue string
if columnValue := row[oldColumnName]; columnValue != nil {
value, ok := columnValue.(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
blobValue = value
}
columnIpValue := net.IP(blobValue)
if net.ParseIP(columnIpValue.String()) == nil {
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
columnIpValue = net.IPv6loopback
}
jsonValue, err := json.Marshal(columnIpValue)
if err != nil {
return fmt.Errorf("re-encode to JSON: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
return fmt.Errorf("update row: %w", err)
}
}
if indexName != "" {
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
return fmt.Errorf("drop index %s: %w", indexName, err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
}
return nil
}); err != nil {
return err
}
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
return nil
}

View File

@@ -0,0 +1,161 @@
package migration_test
import (
"encoding/gob"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
)
func setupDatabase(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
PrepareStmt: true,
})
require.NoError(t, err, "Failed to open database")
return db
}
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
server.Network
Net net.IPNet `gorm:"serializer:gob"`
}
type account struct {
server.Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
}
err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert Gob data")
var gobStr string
err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error
assert.NoError(t, err, "Failed to fetch Gob data")
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
require.NoError(t, err, "Failed to decode Gob data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail with Gob data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated")
}
func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail with JSON data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
}
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
type location struct {
nbpeer.Location
ConnectionIP net.IP
}
type peer struct {
nbpeer.Peer
Location location `gorm:"embedded;embeddedPrefix:location_"`
}
type account struct {
server.Account
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
err = db.Save(&account{
Account: server.Account{Id: "123"},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
).Error
require.NoError(t, err, "Failed to insert blob data")
var blobValue string
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
assert.NoError(t, err, "Failed to fetch blob data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP blob data")
var jsonStr string
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be migrated")
}
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.Account{
Id: "1234",
PeersG: []nbpeer.Peer{
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
var jsonStr string
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
}

View File

@@ -22,80 +22,93 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
// TODO implement me
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
@@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error {
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
}
@@ -626,7 +639,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) {
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync)
}

View File

@@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if nsGroupToSave == nil {
@@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -36,7 +36,7 @@ type NetworkMap struct {
type Network struct {
Identifier string `json:"id"`
Net net.IPNet `gorm:"serializer:gob"`
Net net.IPNet `gorm:"serializer:json"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.

View File

@@ -88,21 +88,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error {
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
@@ -156,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -278,7 +264,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -362,7 +348,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountLock(account.Id)
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -374,14 +360,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(userID, account)
if err == nil {
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow)
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
@@ -518,25 +504,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
return nil, nil, err
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, nil, err
}
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
@@ -551,8 +519,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if requiresApproval {
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
@@ -563,11 +531,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
am.updateAccountPeers(account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
}
// LoginPeer logs in or registers a peer.
@@ -603,7 +571,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountLock(account.Id)
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
@@ -760,7 +728,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -795,7 +763,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -865,6 +833,10 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
return
}
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})

View File

@@ -13,13 +13,13 @@ type Peer struct {
// ID is an internal ID of the peer
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
AccountID string `json:"-" gorm:"index"`
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
// IP address of the Peer
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
@@ -61,7 +61,7 @@ type PeerStatus struct { //nolint:revive
// Location is a geo location information of a Peer based on public connection IP
type Location struct {
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup
ConnectionIP net.IP `gorm:"serializer:json"` // from grpc peer or reverse proxy headers depends on setup
CountryCode string
CityName string
GeoNameID uint // city level geoname id

View File

@@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -7,7 +7,7 @@ import (
)
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
}
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
}
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
}
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -14,7 +14,7 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -115,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -194,7 +194,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if routeToSave == nil {
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -209,7 +209,7 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
keyDuration := DefaultSetupKeyDuration
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if keyToSave == nil {
@@ -327,7 +327,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -3,6 +3,8 @@ package server
import (
"errors"
"fmt"
"net"
"net/netip"
"path/filepath"
"runtime"
"strings"
@@ -18,6 +20,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@@ -40,6 +43,8 @@ type installation struct {
InstallationIDValue string
}
type migrationFunc func(*gorm.DB) error
// NewSqliteStore restores a store from the file located in the datadir
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
storeStr := "store.db?cache=shared"
@@ -50,8 +55,9 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
})
if err != nil {
return nil, err
@@ -64,13 +70,16 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable
if err := migrate(db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
}
err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
)
if err != nil {
return nil, err
return nil, fmt.Errorf("auto migrate: %w", err)
}
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
@@ -118,17 +127,33 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
return unlock
}
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Tracef("acquiring lock for account %s", accountID)
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Tracef("acquiring write lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Tracef("released lock for account %s in %v", accountID, time.Since(start))
log.Tracef("released write lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
log.Tracef("acquiring read lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
unlock = func() {
mtx.RUnlock()
log.Tracef("released read lock for account %s in %v", accountID, time.Since(start))
}
return unlock
@@ -188,7 +213,8 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
Clauses(clause.OnConflict{UpdateAll: true}).
Create(account)
if result.Error != nil {
return result.Error
}
@@ -253,36 +279,43 @@ func (s *SqliteStore) GetInstallationID() string {
}
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peer nbpeer.Peer
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(peerCopy)
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
return result.Error
}
peer.Status = &peerStatus
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
return s.db.Save(peer).Error
return nil
}
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID)
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
return result.Error
}
peer.Location = peerWithLocation.Location
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
}
return s.db.Save(peer).Error
return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
@@ -390,8 +423,9 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
}
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
var account Account
result := s.db.Model(&account).
result := s.db.Debug().Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, "id = ?", accountID)
@@ -511,6 +545,21 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return s.GetAccount(peer.AccountID)
}
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
@@ -542,3 +591,36 @@ func (s *SqliteStore) Close() error {
func (s *SqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine
}
// migrate migrates the SQLite database to the latest schema
func migrate(db *gorm.DB) error {
migrations := getMigrations()
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
func getMigrations() []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
},
}
}

View File

@@ -2,16 +2,23 @@ package server
import (
"fmt"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -30,6 +37,151 @@ func TestSqlite_NewStore(t *testing.T) {
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
@@ -349,6 +501,74 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestMigrate(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
err := migrate(store.db)
require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
Network
Net net.IPNet `gorm:"serializer:gob"`
}
type location struct {
nbpeer.Location
ConnectionIP net.IP
}
type peer struct {
nbpeer.Peer
Location location `gorm:"embedded;embeddedPrefix:location_"`
}
type account struct {
Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
act := &account{
Network: &network{
Net: *ipnet,
},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
},
}
err = store.db.Save(act).Error
require.NoError(t, err, "Failed to insert Gob data")
type route struct {
route2.Route
Network netip.Prefix `gorm:"serializer:gob"`
PeerGroups []string `gorm:"serializer:gob"`
}
prefix := netip.MustParsePrefix("11.0.0.0/24")
rt := &route{
Network: prefix,
PeerGroups: []string{"group1", "group2"},
}
err = store.db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data")
err = migrate(store.db)
require.NoError(t, err, "Migration should not fail on gob populated db")
err = migrate(store.db)
require.NoError(t, err, "Migration should not fail on migrated db")
}
func newSqliteStore(t *testing.T) *SqliteStore {
t.Helper()

View File

@@ -19,6 +19,7 @@ type Store interface {
DeleteAccount(account *Account) error
GetAccountByUser(userID string) (*Account, error)
GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(peerKey string) (string, error)
GetAccountByPeerID(peerID string) (*Account, error)
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(domain string) (*Account, error)
@@ -29,8 +30,10 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string
SaveInstallationID(ID string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
AcquireAccountWriteLock(accountID string) func()
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
AcquireAccountReadLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error

View File

@@ -210,7 +210,7 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -266,7 +266,7 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *User
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if am.idpManager == nil {
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
unlock := am.Store.AcquireAccountLock(account.Id)
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
account, err = am.Store.GetAccount(account.Id)
@@ -400,7 +400,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -537,7 +537,7 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if am.idpManager == nil {
@@ -577,7 +577,7 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if tokenName == "" {
@@ -627,7 +627,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -677,7 +677,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -709,7 +709,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -747,7 +747,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
if update == nil {
@@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users))
users := make(map[string]userLoggedInOnce, len(account.Users))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users {
if user.Issued == UserIssuedIntegration {
@@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = struct{}{}
users[user.Id] = true
continue
}
usersFromIntegration = append(usersFromIntegration, info)
continue
}
if !user.IsServiceUser {
users[user.Id] = struct{}{}
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
}
}
queriedUsers, err = am.lookupCache(users, accountID)

View File

@@ -68,11 +68,11 @@ type Route struct {
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
Network netip.Prefix `gorm:"serializer:gob"`
Network netip.Prefix `gorm:"serializer:json"`
NetID string
Description string
Peer string
PeerGroups []string `gorm:"serializer:gob"`
PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType
Masquerade bool
Metric int
@@ -107,6 +107,12 @@ func (r *Route) Copy() *Route {
// IsEqual compares one route with the other
func (r *Route) IsEqual(other *Route) bool {
if r == nil && other == nil {
return true
} else if r == nil || other == nil {
return false
}
return other.ID == r.ID &&
other.Description == r.Description &&
other.NetID == r.NetID &&

View File

@@ -3,6 +3,8 @@ package grpc
import (
"context"
"net"
"os/user"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -12,6 +14,20 @@ import (
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
log.Fatalf("failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)

View File

@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
@@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
@@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr

View File

@@ -8,6 +8,7 @@ import (
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)

View File

@@ -1,10 +1,16 @@
package net
import "github.com/google/uuid"
import (
"os"
"github.com/google/uuid"
)
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// ConnectionID provides a globally unique identifier for network connections.
@@ -15,3 +21,7 @@ type ConnectionID string
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
setErr = SetSocketOpt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil
}
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}