Compare commits

..

1 Commits

Author SHA1 Message Date
Maycon Santos
28a5b3062b debug ios behavior 2024-04-16 22:58:16 +02:00
79 changed files with 665 additions and 4503 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,
ignore_words_list: erro,clienta
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@@ -38,7 +38,7 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build android netbird lib
@@ -56,10 +56,10 @@ jobs:
with:
go-version: "1.21.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -44,8 +44,7 @@
### Open-Source Network Security in a Single Platform
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f)
### Key features

View File

@@ -1,5 +1,5 @@
FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird
ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /go/bin/netbird

View File

@@ -1,212 +0,0 @@
package anonymize
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"regexp"
"slices"
"strings"
)
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
currentAnonIPv4 netip.Addr
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
return &Anonymizer{
ipAnonymizer: map[netip.Addr]netip.Addr{},
domainAnonymizer: map[string]string{},
currentAnonIPv4: startIPv4,
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
}
}
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
if ip.IsLoopback() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
a.isInAnonymizedRange(ip) {
return ip
}
if _, ok := a.ipAnonymizer[ip]; !ok {
if ip.Is4() {
a.ipAnonymizer[ip] = a.currentAnonIPv4
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
} else {
a.ipAnonymizer[ip] = a.currentAnonIPv6
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
}
}
return a.ipAnonymizer[ip]
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
return true
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
return true
}
return false
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
}
return a.AnonymizeIP(addr).String()
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return uri
}
var anonymizedHost string
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
u.Host = anonymizedHost
}
return u.String()
}
func (a *Anonymizer) AnonymizeString(str string) string {
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
for domain, anonDomain := range a.domainAnonymizer {
str = strings.ReplaceAll(str, domain, anonDomain)
}
str = a.AnonymizeSchemeURI(str)
str = a.AnonymizeDNSLogLine(str)
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
}
if slices.Contains(wellKnown, addr.String()) {
return true
}
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
return cgnatRange.Contains(addr)
}
func generateRandomString(length int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
continue
}
result[i] = letters[num.Int64()]
}
return string(result)
}

View File

@@ -1,223 +0,0 @@
package anonymize_test
import (
"net/netip"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
)
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
name string
ip string
expect string
}{
{"Well known", "8.8.8.8", "8.8.8.8"},
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ip := netip.MustParseAddr(tc.ip)
anonymizedIP := anonymizer.AnonymizeIP(ip)
if anonymizedIP.String() != tc.expect {
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
}
})
}
}
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
domain string
expectPattern string
shouldAnonymize bool
}{
{
"General Domain",
"example.com",
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDomain(tc.domain)
if tc.shouldAnonymize {
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
} else {
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
}
})
}
}
func TestAnonymizeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
uri string
regex string
}{
{
"HTTP URI with Port",
"http://example.com:80/path",
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
},
{
"HTTP URI without Port",
"http://example.com/path",
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
},
{
"Opaque URI with Port",
"stun:example.com:80?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
},
{
"Opaque URI without Port",
"stun:example.com?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeURI(tc.uri)
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizeSchemeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeSchemeURI(tc.input)
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizString_MemorizedDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "This is a test string including the domain example.com which should be anonymized."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
}
func TestAnonymizeString_DoubleURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "Check out our site at https://example.com for more info."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
}
func TestAnonymizeString_IPAddresses(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
tests := []struct {
name string
input string
expect string
}{
{
name: "IPv4 Address",
input: "Error occurred at IP 122.138.1.1",
expect: "Error occurred at IP 198.51.100.0",
},
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeString(tc.input)
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
})
}
}

View File

@@ -1,248 +0,0 @@
package cmd
import (
"context"
"fmt"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
}
var debugBundleCmd = &cobra.Command{
Use: "bundle",
Example: " netbird debug bundle",
Short: "Create a debug bundle",
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
RunE: debugBundle,
}
var logCmd = &cobra.Command{
Use: "log",
Short: "Manage logging for the Netbird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
}
var logLevelCmd = &cobra.Command{
Use: "level <level>",
Short: "Set the logging level for this session",
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
Available log levels are:
panic: for panic level, highest level of severity
fatal: for fatal level errors that cause the program to exit
error: for error conditions
warn: for warning conditions
info: for informational messages
debug: for debug-level messages
trace: for trace-level messages, which include more fine-grained information than debug`,
Args: cobra.ExactArgs(1),
RunE: setLogLevel,
}
var forCmd = &cobra.Command{
Use: "for <time>",
Short: "Run debug logs for a specified duration and create a debug bundle",
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
Example: " netbird debug for 5m",
Args: cobra.ExactArgs(1),
RunE: runForDuration,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := parseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: level,
})
if err != nil {
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level set successfully to", args[0])
return nil
}
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second)
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
startTime := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
elapsed := time.Since(startTime)
if elapsed >= duration {
return
}
remaining := duration - elapsed
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}

View File

@@ -65,7 +65,6 @@ var (
serviceName string
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
@@ -120,8 +119,6 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)
@@ -129,20 +126,8 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -350,14 +335,3 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true
}
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}

View File

@@ -1,131 +0,0 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -24,7 +24,7 @@ var (
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Use: "ssh",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
"You can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}

View File

@@ -6,8 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
@@ -16,7 +14,6 @@ import (
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -147,9 +144,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed initializing log %v", err)
}
ctx := internal.CtxInitState(cmd.Context())
ctx := internal.CtxInitState(context.Background())
resp, err := getStatus(ctx)
resp, err := getStatus(ctx, cmd)
if err != nil {
return err
}
@@ -194,7 +191,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -203,7 +200,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -286,11 +283,6 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
@@ -533,16 +525,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
@@ -554,7 +538,6 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
@@ -610,6 +593,15 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
lastStatusUpdate := "-"
if !peerState.LastStatusUpdate.IsZero() {
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
}
lastWireGuardHandshake := "-"
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
@@ -660,8 +652,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
lastStatusUpdate,
lastWireGuardHandshake,
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
@@ -730,129 +722,3 @@ func countEnabled(dnsServers []nsServerGroupStateOutput) int {
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -3,8 +3,6 @@ package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
@@ -489,15 +487,9 @@ dnsServers:
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
expectedDetail :=
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
@@ -508,8 +500,8 @@ func TestParsingToDetail(t *testing.T) {
Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Last connection update: %s
Last WireGuard handshake: %s
Last connection update: 2001-01-01 01:01:01
Last WireGuard handshake: 2001-01-01 01:01:02
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
@@ -524,16 +516,15 @@ func TestParsingToDetail(t *testing.T) {
Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Last connection update: %s
Last WireGuard handshake: %s
Last connection update: 2002-02-02 02:02:02
Last WireGuard handshake: 2002-02-02 02:02:03
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
CLI version: development
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
@@ -548,7 +539,7 @@ Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
`
assert.Equal(t, expectedDetail, detail)
}
@@ -556,8 +547,8 @@ Peers count: 2/2 Connected
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
expectedString :=
`Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
@@ -581,31 +572,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -250,11 +250,16 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
switch ipLayer {
case layers.LayerTypeIPv4:
// log srcIP and DstIP
log.Infof("--------- srcIP: %v, dstIP: %v", d.ip4.SrcIP, d.ip4.DstIP)
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
log.Infof("--------- srcIP: %v, dstIP: %v dropped", d.ip4.SrcIP, d.ip4.DstIP)
return false
}
case layers.LayerTypeIPv6:
log.Infof("--------- srcIP: %v, dstIP: %v", d.ip6.SrcIP, d.ip6.DstIP)
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
log.Infof("--------- srcIP: %v, dstIP: %v dropped", d.ip6.SrcIP, d.ip6.DstIP)
return false
}
default:
@@ -265,12 +270,14 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
log.Infof("--------- srcIP: %v, dstIP: %v", d.ip4.SrcIP, d.ip4.DstIP)
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
log.Infof("--------- srcIP: %v, dstIP: %v", d.ip6.SrcIP, d.ip6.DstIP)
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
@@ -278,6 +285,8 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
}
}
//
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
@@ -295,8 +304,30 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
return true
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (f bool, o bool) {
ipLayer := d.decoded[0]
payloadLayer := d.decoded[1]
defer func() {
var src, dst net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
src = d.ip4.SrcIP
dst = d.ip4.DstIP
case layers.LayerTypeIPv6:
src = d.ip6.SrcIP
dst = d.ip6.DstIP
}
switch payloadLayer {
case layers.LayerTypeTCP:
log.Infof("--------- TCP srcIP-Port: %v:%d, dstIP-Port: %v:%d Ver: %t,%t", src, uint16(d.tcp.SrcPort), dst, uint16(d.tcp.DstPort), f, o)
case layers.LayerTypeUDP:
log.Infof("--------- UDP srcIP-Port: %v:%d, dstIP-Port: %v:%d Ver: %t,%t", src, uint16(d.udp.SrcPort), dst, uint16(d.udp.DstPort), f, o)
default:
log.Infof("--------- srcIP: %v, dstIP: %v Ver: %t,%t", src, dst, f, o)
}
}()
for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) {
continue

View File

@@ -31,7 +31,7 @@ import (
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
}
// RunClientWithProbes runs the client's main logic with probes attached
@@ -43,9 +43,8 @@ func RunClientWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
}
// RunClientMobile with main logic on mobile system
@@ -67,7 +66,7 @@ func RunClientMobile(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
}
func RunClientiOS(
@@ -83,7 +82,7 @@ func RunClientiOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
}
func runClient(
@@ -95,7 +94,6 @@ func runClient(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
defer func() {
if r := recover(); r != nil {
@@ -245,9 +243,6 @@ func runClient(
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
if engineChan != nil {
engineChan <- engine
}
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
@@ -257,10 +252,6 @@ func runClient(
backOff.Reset()
if engineChan != nil {
engineChan <- nil
}
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)

View File

@@ -31,8 +31,6 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
response := d.lookupRecord(r)
if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)

View File

@@ -308,7 +308,21 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
s.updateMux(muxUpdates)
handler, _ := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
)
handler.upstreamServers = []string{"9.9.9.9:53"}
handler.reactivate = func() {}
handler.deactivate = func(error) {}
s.updateMux(append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
}))
s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())

View File

@@ -78,7 +78,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
u.checkUpstreamFails(err)
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
log.WithField("question", r.Question[0]).Debugf("received an upstream question upstreams %s", u.upstreamServers)
select {
case <-u.ctx.Done():

View File

@@ -111,9 +111,6 @@ type Engine struct {
// TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route
cancel context.CancelFunc
ctx context.Context
@@ -219,8 +216,6 @@ func (e *Engine) Stop() error {
return err
}
e.clientRoutes = nil
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond)
@@ -700,14 +695,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
log.Errorf("failed to update routes, err: %v", err)
}
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
@@ -1237,28 +1229,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
}
}
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
return e.clientRoutes
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 {
id = id[:i]
}
routes[id] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -578,10 +577,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr
return testCase.inputErr
},
}
@@ -598,8 +597,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
})
}
}
@@ -743,8 +742,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
return nil, nil, nil
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}

View File

@@ -1 +0,0 @@
package internal

View File

@@ -3,7 +3,6 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"time"
@@ -156,11 +155,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if currScore != 0 && currScore < chosenScore+0.1 {
return currID
} else {
var peer string
if route := c.routes[chosen]; route != nil {
peer = route.Peer
}
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, peer, chosenScore, c.network)
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network)
}
}
@@ -220,7 +215,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
@@ -261,7 +256,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// otherwise add the route to the system
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
}
@@ -349,15 +344,3 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
}
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
}
return intf
}

View File

@@ -14,7 +14,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -29,9 +28,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelection(map[string][]*route.Route)
GetRouteSelector() *routeselector.RouteSelector
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
@@ -44,7 +41,6 @@ type DefaultManager struct {
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
@@ -58,7 +54,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
@@ -122,29 +117,28 @@ func (m *DefaultManager) Stop() {
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return nil, nil, m.ctx.Err()
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return nil, nil, fmt.Errorf("update routes: %w", err)
return fmt.Errorf("update routes: %w", err)
}
}
return newServerRoutesMap, newClientRoutesIDMap, nil
return nil
}
}
@@ -158,51 +152,16 @@ func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.initialRouteRanges()
}
// GetRouteSelector returns the route selector
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
return m.clientNetworks
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks)
m.stopObsoleteClients(networks)
for id, routes := range networks {
if _, found := m.clientNetworks[id]; found {
// don't touch existing client network watchers
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks)
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
@@ -219,7 +178,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
}
}
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
@@ -251,7 +210,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes)
_, crMap := m.classifiesRoutes(initialRoutes)
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)
@@ -260,9 +219,12 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
if runtime.GOOS == "ios" {
return true
}
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin", "ios":
case "linux", "windows", "darwin":
return true
}
}

View File

@@ -428,11 +428,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
}
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@@ -7,17 +7,14 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelectionFunc func(map[string][]*route.Route)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func()
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
StopFunc func()
}
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -30,25 +27,11 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
}
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)
}
}
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
if m.GetRouteSelectorFunc != nil {
return m.GetRouteSelectorFunc()
}
return nil
return fmt.Errorf("method UpdateRoutes is not implemented")
}
// Start mock implementation of Start from Manager interface

View File

@@ -5,7 +5,6 @@ package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
@@ -18,7 +17,7 @@ import (
type ref struct {
count int
nexthop netip.Addr
intf *net.Interface
intf string
}
type RouteManager struct {
@@ -31,8 +30,8 @@ type RouteManager struct {
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap

View File

@@ -60,13 +60,17 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
return nil
}
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
@@ -80,7 +84,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
@@ -149,7 +153,12 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
func addRouteToNonVPNIntf(
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
@@ -159,34 +168,39 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNe
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, nil, ErrRouteNotAllowed
return netip.Addr{}, "", ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
exitIntf := intf
var exitIntf string
if intf != nil {
exitIntf = intf.Name
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
exitIntf = initialIntf
if initialIntf != nil {
exitIntf = initialIntf.Name
}
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
@@ -194,7 +208,7 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNe
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
@@ -236,7 +250,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
@@ -263,7 +277,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
@@ -329,7 +343,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
func(prefix netip.Prefix) (netip.Addr, string, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {

View File

@@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
func addVPNRoute(netip.Prefix, string) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
func removeVPNRoute(netip.Prefix, string) error {
return nil
}

View File

@@ -27,15 +27,15 @@ func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return routeCmd("add", prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return routeCmd("delete", prefix, nexthop, intf)
}
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
@@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = &net.Interface{Name: "lo0"}
intf = "lo0"
}
}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != nil {
args = append(args, "-interface", intf.Name)
} else if intf != "" {
args = append(args, "-interface", intf)
}
if err := retryRouteCmd(args); err != nil {

View File

@@ -33,7 +33,7 @@ func init() {
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
intf := "lo0"
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {

View File

@@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
func addVPNRoute(netip.Prefix, string) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
func removeVPNRoute(netip.Prefix, string) error {
return nil
}

View File

@@ -46,6 +46,9 @@ var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
@@ -59,20 +62,6 @@ type ruleParams struct {
description string
}
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams {
return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
@@ -93,7 +82,7 @@ func getSetupRules() []ruleParams {
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy() {
if isLegacy {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
@@ -122,7 +111,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
isLegacy = true
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
@@ -136,7 +125,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
if isLegacy() {
if isLegacy {
return cleanupRoutingWithRouteManager(routeManager)
}
@@ -165,16 +154,16 @@ func cleanupRouting() error {
return result.ErrorOrNil()
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
func addVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
return genericAddVPNRoute(prefix, intf)
}
@@ -196,8 +185,8 @@ func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
func removeVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
return genericRemoveVPNRoute(prefix, intf)
}
@@ -255,7 +244,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
@@ -315,10 +304,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}
@@ -327,7 +313,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -481,24 +467,22 @@ func removeRule(params ruleParams) error {
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if intf != nil {
route.LinkIndex = intf.Index
}
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
route.Gw = addr.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
if intf == "" {
intf = addr.Zone()
}
}
if intf != "" {
link, err := netlink.LinkByName(intf)
if err != nil {
return fmt.Errorf("set interface %s: %w", intf, err)
}
route.LinkIndex = link.Attrs().Index
}
return nil
}

View File

@@ -3,7 +3,6 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
@@ -15,10 +14,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func removeVPNRoute(prefix netip.Prefix, intf string) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -50,8 +50,6 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@@ -69,11 +67,7 @@ func TestAddRemoveRoutes(t *testing.T) {
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf)
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
@@ -84,7 +78,7 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = removeVPNRoute(testCase.prefix, intf)
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
@@ -188,16 +182,12 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@@ -210,18 +200,14 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addVPNRoute(testCase.preExistingPrefix, intf)
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addVPNRoute(testCase.prefix, intf)
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@@ -231,7 +217,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = removeVPNRoute(testCase.prefix, intf)
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "should not return err")
}
@@ -359,47 +345,43 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}

View File

@@ -6,12 +6,8 @@ import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
@@ -25,10 +21,6 @@ type Win32_IP4RouteTable struct {
Mask string
}
var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -40,23 +32,15 @@ func cleanupRouting() error {
}
func getRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
var routes []Win32_IP4RouteTable
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
var routes []Win32_IP4RouteTable
err := wmi.Query(query, &routes)
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
prefixList = nil
var prefixList []netip.Prefix
for _, route := range routes {
addr, err := netip.ParseAddr(route.Destination)
if err != nil {
@@ -76,29 +60,54 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
prefixList = append(prefixList, routePrefix)
}
}
lastUpdate = time.Now()
return prefixList, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
args := []string{"add", prefix.String()}
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
destinationPrefix := prefix.String()
psCmd := "New-NetRoute"
addressFamily := "IPv4"
if prefix.Addr().Is6() {
addressFamily = "IPv6"
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix,
)
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
script = fmt.Sprintf(
`%s -NextHop "%s"`, script, nexthop,
)
}
if intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index))
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
log.Tracef("PowerShell %s: %s", script, string(out))
if err != nil {
return fmt.Errorf("PowerShell add route: %w", err)
}
return nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
@@ -107,20 +116,21 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
var intfIdx string
if nexthop.Zone() != "" {
intfIdx = nexthop.Zone()
nexthop.WithZone("")
}
// Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
@@ -135,7 +145,3 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interf
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -1,132 +0,0 @@
package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
route "github.com/netbirdio/netbird/route"
)
type RouteSelector struct {
selectedRoutes map[string]struct{}
selectAll bool
}
func NewRouteSelector() *RouteSelector {
return &RouteSelector{
selectedRoutes: map[string]struct{}{},
// default selects all routes
selectAll: true,
}
}
// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
if !appendRoute {
rs.selectedRoutes = map[string]struct{}{}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
rs.selectedRoutes[route] = struct{}{}
}
rs.selectAll = false
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{}
}
// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool {
if rs.selectAll {
return true
}
_, selected := rs.selectedRoutes[routeID]
return selected
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
if rs.selectAll {
return maps.Clone(routes)
}
filtered := map[string][]*route.Route{}
for id, rt := range routes {
netID := id
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt
}
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -1,275 +0,0 @@
package routeselector_test
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
selectRoutes []string
append bool
wantSelected []string
wantError bool
}{
{
name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes, initial all deselected",
initialSelected: []string{},
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2", "route3"},
wantSelected: []string{"route2", "route3"},
},
{
name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route1"},
wantError: true,
},
{
name: "Append route with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route1", "route2"},
},
{
name: "Append route without initial selection",
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial some selected",
initialSelected: []string{"route1"},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.SelectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
deselectRoutes []string
wantSelected []string
wantError bool
}{
{
name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route3"},
},
{
name: "Deselect specific routes, initial all deselected",
initialSelected: []string{},
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"},
wantSelected: []string{"route2"},
},
{
name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route2"},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{},
},
{
name: "Initial some selected",
initialSelected: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.DeselectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
assert.True(t, rs.IsSelected("route1"))
assert.True(t, rs.IsSelected("route2"))
assert.False(t, rs.IsSelected("route3"))
assert.False(t, rs.IsSelected("route4"))
}
func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
routes := map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
}, filtered)
}

View File

@@ -71,42 +71,6 @@ func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled store if rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)

View File

@@ -1,17 +1,17 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.12.4
// protoc v4.24.3
// source: daemon.proto
package proto
import (
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
duration "github.com/golang/protobuf/ptypes/duration"
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
)
@@ -23,70 +23,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type LogLevel int32
const (
LogLevel_UNKNOWN LogLevel = 0
LogLevel_PANIC LogLevel = 1
LogLevel_FATAL LogLevel = 2
LogLevel_ERROR LogLevel = 3
LogLevel_WARN LogLevel = 4
LogLevel_INFO LogLevel = 5
LogLevel_DEBUG LogLevel = 6
LogLevel_TRACE LogLevel = 7
)
// Enum value maps for LogLevel.
var (
LogLevel_name = map[int32]string{
0: "UNKNOWN",
1: "PANIC",
2: "FATAL",
3: "ERROR",
4: "WARN",
5: "INFO",
6: "DEBUG",
7: "TRACE",
}
LogLevel_value = map[string]int32{
"UNKNOWN": 0,
"PANIC": 1,
"FATAL": 2,
"ERROR": 3,
"WARN": 4,
"INFO": 5,
"DEBUG": 6,
"TRACE": 7,
}
)
func (x LogLevel) Enum() *LogLevel {
p := new(LogLevel)
*p = x
return p
}
func (x LogLevel) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (LogLevel) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[0].Descriptor()
}
func (LogLevel) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[0]
}
func (x LogLevel) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use LogLevel.Descriptor instead.
func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
type LoginRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -830,23 +766,23 @@ type PeerState struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *duration.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
}
func (x *PeerState) Reset() {
@@ -902,7 +838,7 @@ func (x *PeerState) GetConnStatus() string {
return ""
}
func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp {
func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp {
if x != nil {
return x.ConnStatusUpdate
}
@@ -958,7 +894,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string {
return ""
}
func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp {
func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp {
if x != nil {
return x.LastWireguardHandshake
}
@@ -993,7 +929,7 @@ func (x *PeerState) GetRoutes() []string {
return nil
}
func (x *PeerState) GetLatency() *duration.Duration {
func (x *PeerState) GetLatency() *durationpb.Duration {
if x != nil {
return x.Latency
}
@@ -1447,442 +1383,6 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState {
return nil
}
type ListRoutesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ListRoutesRequest) Reset() {
*x = ListRoutesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[19]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListRoutesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListRoutesRequest) ProtoMessage() {}
func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[19]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead.
func (*ListRoutesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{19}
}
type ListRoutesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"`
}
func (x *ListRoutesResponse) Reset() {
*x = ListRoutesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListRoutesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListRoutesResponse) ProtoMessage() {}
func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[20]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead.
func (*ListRoutesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{20}
}
func (x *ListRoutesResponse) GetRoutes() []*Route {
if x != nil {
return x.Routes
}
return nil
}
type SelectRoutesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"`
Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"`
All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *SelectRoutesRequest) Reset() {
*x = SelectRoutesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SelectRoutesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SelectRoutesRequest) ProtoMessage() {}
func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead.
func (*SelectRoutesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{21}
}
func (x *SelectRoutesRequest) GetRouteIDs() []string {
if x != nil {
return x.RouteIDs
}
return nil
}
func (x *SelectRoutesRequest) GetAppend() bool {
if x != nil {
return x.Append
}
return false
}
func (x *SelectRoutesRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
type SelectRoutesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SelectRoutesResponse) Reset() {
*x = SelectRoutesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[22]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SelectRoutesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SelectRoutesResponse) ProtoMessage() {}
func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[22]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead.
func (*SelectRoutesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{22}
}
type Route struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"`
Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"`
Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"`
}
func (x *Route) Reset() {
*x = Route{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Route) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Route) ProtoMessage() {}
func (x *Route) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[23]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Route.ProtoReflect.Descriptor instead.
func (*Route) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{23}
}
func (x *Route) GetID() string {
if x != nil {
return x.ID
}
return ""
}
func (x *Route) GetNetwork() string {
if x != nil {
return x.Network
}
return ""
}
func (x *Route) GetSelected() bool {
if x != nil {
return x.Selected
}
return false
}
type DebugBundleRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
}
func (x *DebugBundleRequest) Reset() {
*x = DebugBundleRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DebugBundleRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DebugBundleRequest) ProtoMessage() {}
func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead.
func (*DebugBundleRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{24}
}
func (x *DebugBundleRequest) GetAnonymize() bool {
if x != nil {
return x.Anonymize
}
return false
}
func (x *DebugBundleRequest) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
type DebugBundleResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
}
func (x *DebugBundleResponse) Reset() {
*x = DebugBundleResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[25]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DebugBundleResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DebugBundleResponse) ProtoMessage() {}
func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[25]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead.
func (*DebugBundleResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{25}
}
func (x *DebugBundleResponse) GetPath() string {
if x != nil {
return x.Path
}
return ""
}
type SetLogLevelRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Level LogLevel `protobuf:"varint,1,opt,name=level,proto3,enum=daemon.LogLevel" json:"level,omitempty"`
}
func (x *SetLogLevelRequest) Reset() {
*x = SetLogLevelRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[26]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetLogLevelRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetLogLevelRequest) ProtoMessage() {}
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[26]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{26}
}
func (x *SetLogLevelRequest) GetLevel() LogLevel {
if x != nil {
return x.Level
}
return LogLevel_UNKNOWN
}
type SetLogLevelResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SetLogLevelResponse) Reset() {
*x = SetLogLevelResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[27]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetLogLevelResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetLogLevelResponse) ProtoMessage() {}
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[27]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{27}
}
var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{
@@ -2101,92 +1601,32 @@ var file_daemon_proto_rawDesc = []byte{
0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65,
0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74,
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a,
0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74,
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22,
0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01,
0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c,
0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14,
0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4d, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x65, 0x64, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a,
0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49,
0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09,
0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52,
0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a,
0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43,
0x45, 0x10, 0x07, 0x32, 0xee, 0x05, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70,
0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a,
0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c,
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02,
0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53,
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61,
0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33,
0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -2201,81 +1641,58 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 28)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19)
var file_daemon_proto_goTypes = []interface{}{
(LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest
(*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 5: daemon.UpRequest
(*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusRequest)(nil), // 7: daemon.StatusRequest
(*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownRequest)(nil), // 9: daemon.DownRequest
(*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*PeerState)(nil), // 13: daemon.PeerState
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*SignalState)(nil), // 15: daemon.SignalState
(*ManagementState)(nil), // 16: daemon.ManagementState
(*RelayState)(nil), // 17: daemon.RelayState
(*NSGroupState)(nil), // 18: daemon.NSGroupState
(*FullStatus)(nil), // 19: daemon.FullStatus
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*Route)(nil), // 24: daemon.Route
(*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse
(*SetLogLevelRequest)(nil), // 27: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 28: daemon.SetLogLevelResponse
(*timestamp.Timestamp)(nil), // 29: google.protobuf.Timestamp
(*duration.Duration)(nil), // 30: google.protobuf.Duration
(*LoginRequest)(nil), // 0: daemon.LoginRequest
(*LoginResponse)(nil), // 1: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 4: daemon.UpRequest
(*UpResponse)(nil), // 5: daemon.UpResponse
(*StatusRequest)(nil), // 6: daemon.StatusRequest
(*StatusResponse)(nil), // 7: daemon.StatusResponse
(*DownRequest)(nil), // 8: daemon.DownRequest
(*DownResponse)(nil), // 9: daemon.DownResponse
(*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse
(*PeerState)(nil), // 12: daemon.PeerState
(*LocalPeerState)(nil), // 13: daemon.LocalPeerState
(*SignalState)(nil), // 14: daemon.SignalState
(*ManagementState)(nil), // 15: daemon.ManagementState
(*RelayState)(nil), // 16: daemon.RelayState
(*NSGroupState)(nil), // 17: daemon.NSGroupState
(*FullStatus)(nil), // 18: daemon.FullStatus
(*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp
(*durationpb.Duration)(nil), // 20: google.protobuf.Duration
}
var file_daemon_proto_depIdxs = []int32{
19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
29, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
29, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
30, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
13, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
0, // 11: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
1, // 12: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
3, // 13: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
5, // 14: daemon.DaemonService.Up:input_type -> daemon.UpRequest
7, // 15: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
9, // 16: daemon.DaemonService.Down:input_type -> daemon.DownRequest
11, // 17: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
20, // 18: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 19: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
22, // 20: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
25, // 21: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
27, // 22: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
2, // 23: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
4, // 24: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
6, // 25: daemon.DaemonService.Up:output_type -> daemon.UpResponse
8, // 26: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 27: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 28: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 29: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 30: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 31: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
26, // 32: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
28, // 33: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
23, // [23:34] is the sub-list for method output_type
12, // [12:23] is the sub-list for method input_type
12, // [12:12] is the sub-list for extension type_name
12, // [12:12] is the sub-list for extension extendee
0, // [0:12] is the sub-list for field type_name
18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
20, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
15, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
14, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
13, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
12, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
16, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
17, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
0, // 10: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
2, // 11: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
4, // 12: daemon.DaemonService.Up:input_type -> daemon.UpRequest
6, // 13: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
8, // 14: daemon.DaemonService.Down:input_type -> daemon.DownRequest
10, // 15: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
1, // 16: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
3, // 17: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
5, // 18: daemon.DaemonService.Up:output_type -> daemon.UpResponse
7, // 19: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
9, // 20: daemon.DaemonService.Down:output_type -> daemon.DownResponse
11, // 21: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
16, // [16:22] is the sub-list for method output_type
10, // [10:16] is the sub-list for method input_type
10, // [10:10] is the sub-list for extension type_name
10, // [10:10] is the sub-list for extension extendee
0, // [0:10] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -2512,114 +1929,6 @@ func file_daemon_proto_init() {
return nil
}
}
file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListRoutesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListRoutesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelectRoutesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelectRoutesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Route); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DebugBundleRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DebugBundleResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
type x struct{}
@@ -2627,14 +1936,13 @@ func file_daemon_proto_init() {
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 1,
NumMessages: 28,
NumEnums: 0,
NumMessages: 19,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_daemon_proto_goTypes,
DependencyIndexes: file_daemon_proto_depIdxs,
EnumInfos: file_daemon_proto_enumTypes,
MessageInfos: file_daemon_proto_msgTypes,
}.Build()
File_daemon_proto = out.File

View File

@@ -27,21 +27,6 @@ service DaemonService {
// GetConfig of the daemon.
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {}
// List available network routes
rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {}
// Select specific routes
rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// Deselect specific routes
rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// DebugBundle creates a debug bundle
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
// SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
};
message LoginRequest {
@@ -210,53 +195,4 @@ message FullStatus {
repeated PeerState peers = 4;
repeated RelayState relays = 5;
repeated NSGroupState dns_servers = 6;
}
message ListRoutesRequest {
}
message ListRoutesResponse {
repeated Route routes = 1;
}
message SelectRoutesRequest {
repeated string routeIDs = 1;
bool append = 2;
bool all = 3;
}
message SelectRoutesResponse {
}
message Route {
string ID = 1;
string network = 2;
bool selected = 3;
}
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
}
message DebugBundleResponse {
string path = 1;
}
enum LogLevel {
UNKNOWN = 0;
PANIC = 1;
FATAL = 2;
ERROR = 3;
WARN = 4;
INFO = 5;
DEBUG = 6;
TRACE = 7;
}
message SetLogLevelRequest {
LogLevel level = 1;
}
message SetLogLevelResponse {
}

View File

@@ -31,16 +31,6 @@ type DaemonServiceClient interface {
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
// List available network routes
ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
}
type daemonServiceClient struct {
@@ -105,51 +95,6 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
return out, nil
}
func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) {
out := new(ListRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -167,16 +112,6 @@ type DaemonServiceServer interface {
Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
// List available network routes
ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -202,21 +137,6 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do
func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented")
}
func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
}
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -338,96 +258,6 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SelectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SelectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeselectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DebugBundleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DebugBundle(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DebugBundle",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetLogLevelRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetLogLevel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetLogLevel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -459,26 +289,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetConfig",
Handler: _DaemonService_GetConfig_Handler,
},
{
MethodName: "ListRoutes",
Handler: _DaemonService_ListRoutes_Handler,
},
{
MethodName: "SelectRoutes",
Handler: _DaemonService_SelectRoutes_Handler,
},
{
MethodName: "DeselectRoutes",
Handler: _DaemonService_DeselectRoutes_Handler,
},
{
MethodName: "DebugBundle",
Handler: _DaemonService_DebugBundle_Handler,
},
{
MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto",

View File

@@ -1,175 +0,0 @@
package server
import (
"archive/zip"
"bufio"
"context"
"fmt"
"io"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
if err := bundlePath.Close(); err != nil {
log.Errorf("failed to close zip file: %v", err)
}
if err != nil {
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
log.Errorf("Failed to remove zip file: %v", err2)
}
}
}()
archive := zip.NewWriter(bundlePath)
defer func() {
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
scanner := bufio.NewScanner(reader)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
defer func() {
if err := writer.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err)
}
}()
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
errChan <- fmt.Errorf("write line to writer: %w", err)
return
}
}
if err := scanner.Err(); err != nil {
errChan <- fmt.Errorf("read line from scanner: %w", err)
return
}
}
// SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
level, err := log.ParseLevel(req.Level.String())
if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err)
}
log.SetLevel(level)
log.Infof("Log level set to %s", level.String())
return &proto.SetLogLevelResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
Method: zip.Deflate,
}
writer, err := archive.CreateHeader(header)
if err != nil {
return fmt.Errorf("create zip file header: %w", err)
}
if _, err := io.Copy(writer, reader); err != nil {
return fmt.Errorf("write file to zip: %w", err)
}
return nil
}
func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
}
for _, nsGroup := range status.NSGroupStates {
for _, domain := range nsGroup.Domains {
a.AnonymizeDomain(domain)
}
}
for _, relay := range status.Relays {
if relay.URI != nil {
a.AnonymizeURI(relay.URI.String())
}
}
}

View File

@@ -1,110 +0,0 @@
package server
import (
"context"
"fmt"
"net/netip"
"sort"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
)
type selectRoute struct {
NetID string
Network netip.Prefix
Selected bool
}
// ListRoutes returns a list of all available routes.
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.engine == nil {
return nil, fmt.Errorf("not connected")
}
routesMap := s.engine.GetClientRoutesWithNetID()
routeSelector := s.engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
route := &selectRoute{
NetID: id,
Network: rt[0].Network,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
}
sort.Slice(routes, func(i, j int) bool {
iPrefix := routes[i].Network.Bits()
jPrefix := routes[j].Network.Bits()
if iPrefix == jPrefix {
iAddr := routes[i].Network.Addr()
jAddr := routes[j].Network.Addr()
if iAddr == jAddr {
return routes[i].NetID < routes[j].NetID
}
return iAddr.String() < jAddr.String()
}
return iPrefix < jPrefix
})
var pbRoutes []*proto.Route
for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{
ID: route.NetID,
Network: route.Network.String(),
Selected: route.Selected,
})
}
return &proto.ListRoutesResponse{
Routes: pbRoutes,
}, nil
}
// SelectRoutes selects specific routes based on the client request.
func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.SelectAllRoutes()
} else {
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}
// DeselectRoutes deselects specific routes based on the client request.
func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.DeselectAllRoutes()
} else {
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}

View File

@@ -15,15 +15,15 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -57,8 +57,6 @@ type Server struct {
config *internal.Config
proto.UnimplementedDaemonServiceServer
engine *internal.Engine
statusRecorder *peer.Status
sessionWatcher *internal.SessionWatcher
@@ -143,11 +141,8 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
}
return nil
@@ -158,7 +153,6 @@ func (s *Server) Start() error {
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
engineChan chan<- *internal.Engine,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -188,7 +182,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@@ -568,10 +562,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
return &proto.UpResponse{}, nil
}
@@ -588,8 +579,6 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.engine = nil
return &proto.DownResponse{}, nil
}
@@ -672,6 +661,7 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
PreSharedKey: preSharedKey,
}, nil
}
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
@@ -683,22 +673,6 @@ func (s *Server) onSessionExpire() {
}
}
// watchEngine watches the engine channel and updates the engine state
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
log.Tracef("Started watching engine")
for {
select {
case <-ctx.Done():
s.engine = nil
log.Tracef("Stopped watching engine")
return
case engine := <-engineChan:
log.Tracef("Received engine from watcher")
s.engine = engine
}
}
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},

View File

@@ -2,12 +2,11 @@ package server
import (
"context"
"github.com/netbirdio/management-integrations/integrations"
"net"
"testing"
"time"
"github.com/netbirdio/management-integrations/integrations"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -70,7 +69,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}

View File

@@ -1,5 +1,7 @@
//go:build !(linux && 386)
// +build !linux !386
// skipping linux 32 bits build and tests
package main
import (
@@ -56,23 +58,14 @@ func main() {
var showSettings bool
flag.BoolVar(&showSettings, "settings", false, "run settings windows")
var showRoutes bool
flag.BoolVar(&showRoutes, "routes", false, "run routes windows")
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
flag.Parse()
a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
if errorMSG != "" {
showErrorMSG(errorMSG)
return
}
client := newServiceClient(daemonAddr, a, showSettings, showRoutes)
if showSettings || showRoutes {
client := newServiceClient(daemonAddr, a, showSettings)
if showSettings {
a.Run()
} else {
if err := checkPIDFile(); err != nil {
@@ -135,7 +128,6 @@ type serviceClient struct {
mVersionDaemon *systray.MenuItem
mUpdate *systray.MenuItem
mQuit *systray.MenuItem
mRoutes *systray.MenuItem
// application with main windows.
app fyne.App
@@ -160,15 +152,12 @@ type serviceClient struct {
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
showRoutes bool
wRoutes fyne.Window
}
// newServiceClient instance constructor
//
// This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient {
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient {
s := &serviceClient{
ctx: context.Background(),
addr: addr,
@@ -176,7 +165,6 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo
sendNotification: false,
showSettings: showSettings,
showRoutes: showRoutes,
update: version.NewUpdate(),
}
@@ -196,16 +184,14 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo
}
if showSettings {
s.showSettingsUI()
s.showUIElements()
return s
} else if showRoutes {
s.showRoutesUI()
}
return s
}
func (s *serviceClient) showSettingsUI() {
func (s *serviceClient) showUIElements() {
// add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings")
s.iMngURL = widget.NewEntry()
@@ -223,18 +209,6 @@ func (s *serviceClient) showSettingsUI() {
s.wSettings.Show()
}
// showErrorMSG opens a fyne app window to display the supplied message
func showErrorMSG(msg string) {
app := app.New()
w := app.NewWindow("NetBird Error")
content := widget.NewLabel(msg)
content.Wrapping = fyne.TextWrapWord
w.SetContent(content)
w.Resize(fyne.NewSize(400, 100))
w.Show()
app.Run()
}
// getSettingsForm to embed it into settings window.
func (s *serviceClient) getSettingsForm() *widget.Form {
return &widget.Form{
@@ -423,7 +397,6 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Connected")
s.mUp.Disable()
s.mDown.Enable()
s.mRoutes.Enable()
systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
s.connected = false
@@ -436,7 +409,6 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
s.mRoutes.Disable()
systrayIconState = false
}
@@ -492,11 +464,9 @@ func (s *serviceClient) onTrayReady() {
s.mUp = systray.AddMenuItem("Connect", "Connect")
s.mDown = systray.AddMenuItem("Disconnect", "Disconnect")
s.mDown.Disable()
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel")
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Wiretrustee Admin Panel")
systray.AddSeparator()
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window")
s.mRoutes.Disable()
systray.AddSeparator()
s.mAbout = systray.AddMenuItem("About", "About")
@@ -534,22 +504,16 @@ func (s *serviceClient) onTrayReady() {
case <-s.mAdminPanel.ClickedCh:
err = open.Run(s.adminURL)
case <-s.mUp.ClickedCh:
s.mUp.Disabled()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
if err != nil {
s.runSelfCommand("error-msg", err.Error())
return
}
}()
case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() {
defer s.mDown.Enable()
err := s.menuDownClick()
if err != nil {
s.runSelfCommand("error-msg", err.Error())
return
}
}()
@@ -557,8 +521,24 @@ func (s *serviceClient) onTrayReady() {
s.mSettings.Disable()
go func() {
defer s.mSettings.Enable()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
proc, err := os.Executable()
if err != nil {
log.Errorf("show settings: %v", err)
return
}
cmd := exec.Command(proc, "--settings=true")
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start settings UI: %v, %s", err, string(out))
return
}
if len(out) != 0 {
log.Info("settings change:", string(out))
}
// update config in systray when settings windows closed
s.getSrvConfig()
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
@@ -568,12 +548,6 @@ func (s *serviceClient) onTrayReady() {
if err != nil {
log.Errorf("%s", err)
}
case <-s.mRoutes.ClickedCh:
s.mRoutes.Disable()
go func() {
defer s.mRoutes.Enable()
s.runSelfCommand("routes", "true")
}()
}
if err != nil {
log.Errorf("process connection: %v", err)
@@ -582,24 +556,6 @@ func (s *serviceClient) onTrayReady() {
}()
}
func (s *serviceClient) runSelfCommand(command, arg string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("show %s failed with error: %v", command, err)
return
}
cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg))
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start %s UI: %v, %s", command, err, string(out))
return
}
if len(out) != 0 {
log.Infof("command %s executed: %s", command, string(out))
}
}
func normalizedVersion(version string) string {
versionString := version
if unicode.IsDigit(rune(versionString[0])) {

View File

@@ -1,203 +0,0 @@
//go:build !(linux && 386)
package main
import (
"fmt"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(2))
go s.updateRoutes(grid)
routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid)
scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox(
layout.NewSpacer(),
widget.NewButton("Refresh", func() {
s.updateRoutes(grid)
}),
widget.NewButton("Select all", func() {
s.selectAllRoutes()
s.updateRoutes(grid)
}),
widget.NewButton("Deselect All", func() {
s.deselectAllRoutes()
s.updateRoutes(grid)
}),
layout.NewSpacer(),
)
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(5*time.Second, grid)
}
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
grid.Objects = nil
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
grid.Add(idHeader)
grid.Add(networkHeader)
for _, route := range routes {
r := route
checkBox := widget.NewCheck(r.ID, func(checked bool) {
s.selectRoute(r.ID, checked)
})
checkBox.Checked = route.Selected
checkBox.Resize(fyne.NewSize(20, 20))
checkBox.Refresh()
grid.Add(checkBox)
grid.Add(widget.NewLabel(r.Network))
}
s.wRoutes.Content().Refresh()
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
}
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list routes: %v", err)
}
return resp.Routes, nil
}
func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
req := &proto.SelectRoutesRequest{
RouteIDs: []string{id},
Append: checked,
}
if checked {
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select route: %v", err)
s.showError(fmt.Errorf("failed to select route: %v", err))
return
}
log.Infof("Route %s selected", id)
} else {
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect route: %v", err)
s.showError(fmt.Errorf("failed to deselect route: %v", err))
return
}
log.Infof("Route %s deselected", id)
}
}
func (s *serviceClient) selectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err))
return
}
log.Debug("All routes selected")
}
func (s *serviceClient) deselectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
return
}
log.Debug("All routes deselected")
}
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
s.updateRoutes(grid)
}
}()
s.wRoutes.SetOnClosed(func() {
ticker.Stop()
})
}
// wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string {
var sb strings.Builder
var currentLineLength int
for _, runeValue := range text {
sb.WriteRune(runeValue)
currentLineLength++
if currentLineLength >= lineLength || runeValue == '\n' {
sb.WriteRune('\n')
currentLineLength = 0
}
}
return sb.String()
}

8
go.mod
View File

@@ -21,8 +21,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/crypto v0.21.0
golang.org/x/sys v0.18.0
golang.org/x/crypto v0.18.0
golang.org/x/sys v0.16.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -82,10 +82,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.23.0
golang.org/x/net v0.20.0
golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.3.0
golang.org/x/term v0.18.0
golang.org/x/term v0.16.0
google.golang.org/api v0.126.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.5.3

12
go.sum
View File

@@ -581,9 +581,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -670,9 +669,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -763,18 +761,16 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -11,7 +11,6 @@ import (
"net/netip"
"reflect"
"regexp"
"runtime/debug"
"strings"
"sync"
"time"
@@ -47,8 +46,6 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour
)
type userLoggedInOnce bool
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
func cacheEntryExpiration() time.Duration {
@@ -77,7 +74,7 @@ type AccountManager interface {
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error)
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error
DeletePeer(accountID, peerID, userID string) error
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error)
@@ -118,8 +115,8 @@ type AccountManager interface {
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
@@ -131,8 +128,6 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
GroupValidation(accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
}
type DefaultAccountManager struct {
@@ -389,8 +384,6 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
log.Debugf("GetNetworkMap with trace: %s", string(debug.Stack()))
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@@ -963,7 +956,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -1014,7 +1007,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -1099,21 +1092,19 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
}
delete(userData, idp.UnsetAccountID)
rcvdUsers := 0
for accountID, users := range userData {
rcvdUsers += len(users)
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil {
return err
}
}
log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
log.Infof("warmed up IDP cache with %d entries", len(userData))
return nil
}
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -1272,7 +1263,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]userLoggedInOnce, len(account.Users))
users := make(map[string]struct{}, len(account.Users))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
if user.IsServiceUser {
@@ -1281,7 +1272,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
if user.Issued == UserIssuedIntegration {
continue
}
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
users[user.Id] = struct{}{}
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)
@@ -1354,57 +1345,22 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
}
}
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
var data []*idp.UserData
var err error
maxAttempts := 2
data, err = am.getAccountFromCache(accountID, false)
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
if err != nil {
return nil, err
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
if am.isCacheFresh(accountUsers, data) {
return data, nil
}
if attempt > 1 {
time.Sleep(200 * time.Millisecond)
}
log.Infof("refreshing cache for account %s", accountID)
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
if attempt == maxAttempts {
log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts)
}
}
return data, nil
}
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
userDataMap := make(map[string]*idp.UserData, len(data))
userDataMap := make(map[string]struct{})
for _, datum := range data {
userDataMap[datum.ID] = datum
userDataMap[datum.ID] = struct{}{}
}
// the accountUsers ID list of non integration users from store, we check if cache has all of them
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
knownUsersCount := len(accountUsers)
for user, loggedInOnce := range accountUsers {
if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
log.Infof("user %s has a pending invite and has logged in once, cache invalid", user)
return false
}
for user := range accountUsers {
if _, ok := userDataMap[user]; ok {
knownUsersCount--
continue
}
@@ -1413,11 +1369,15 @@ func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLogged
// if we know users that are not yet in cache more likely cache is outdated
if knownUsersCount > 0 {
log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount)
return false
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount)
// reload cache once avoiding loops
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
}
return true
return data, err
}
func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error {
@@ -1465,14 +1425,29 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account,
}
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(
existingAcc *Account,
primaryDomain bool,
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) error {
err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain)
if err != nil {
return err
var err error
if domainAcc != nil && existingAcc.Id != domainAcc.Id {
err = am.updateAccountDomainAttributes(existingAcc, claims, false)
if err != nil {
return err
}
} else {
err = am.updateAccountDomainAttributes(existingAcc, claims, true)
if err != nil {
return err
}
}
// we should register the account ID to this user's metadata in our IDP manager
@@ -1572,7 +1547,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
return err
}
unlock := am.Store.AcquireAccountWriteLock(account.Id)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(user.Id)
@@ -1655,7 +1630,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireAccountWriteLock(newAcc.Id)
unlock := am.Store.AcquireAccountLock(newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
@@ -1674,7 +1649,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
if !user.IsServiceUser && claims.Invited {
if !user.IsServiceUser {
err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, nil, err
@@ -1806,33 +1781,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil {
unlockAccount := am.Store.AcquireAccountWriteLock(account.Id)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(claims.UserId)
if err != nil {
return nil, err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
err = am.handleExistingUserAccount(account, primaryDomain, claims)
err = am.handleExistingUserAccount(account, domainAccount, claims)
if err != nil {
return nil, err
}
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
if err != nil {
return nil, err
}
}
return am.handleNewUserAccount(domainAccount, claims)
} else {
// other error
@@ -1840,56 +1794,6 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
if err != nil {
return nil, nil, mapError(err)
}
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, nil
}
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
err = am.MarkPeerConnected(peer.Key, false, nil, account)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
}
return nil
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil

View File

@@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1666,10 +1666,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
@@ -1735,10 +1732,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1750,7 +1745,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1761,10 +1756,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -12,7 +12,7 @@ import (
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
return unlock
}
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
// AcquireAccountLock acquires account lock and returns a function that releases the lock
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
@@ -295,12 +295,6 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
return unlock
}
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID)
}
func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -578,18 +572,6 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return account.Copy(), nil
}
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}
return accountID, nil
}
// GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string {
return s.InstallationID

View File

@@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -214,7 +214,7 @@ func difference(a, b []string) []string {
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountId)
unlock := am.Store.AcquireAccountLock(accountId)
defer unlock()
account, err := am.Store.GetAccount(accountId)
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err
}
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
if err != nil {
return err
return mapError(err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
@@ -149,6 +149,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(peer)
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
}
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
}
@@ -202,7 +207,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(peer)
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
s.ephemeralManager.OnPeerDisconnected(peer)
}

View File

@@ -181,8 +181,8 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil {
util.WriteError(err, w)

View File

@@ -198,6 +198,8 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
serviceUser := r.URL.Query().Get("service_user")
log.Debugf("UserCount: %v", len(data))
users := make([]*api.User, 0)
for _, r := range data {
if r.NonDeletable {

View File

@@ -20,8 +20,8 @@ type ErrorResponse struct {
// WriteJSONObject simply writes object to the HTTP response in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(err, w)
@@ -63,8 +63,8 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
// WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(&ErrorResponse{
Message: errMsg,
Code: httpStatus,

View File

@@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
return errors.New("invalid groups")
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(userID)

View File

@@ -13,7 +13,6 @@ type AuthorizationClaims struct {
Domain string
DomainCategory string
LastLogin time.Time
Invited bool
Raw jwt.MapClaims
}

View File

@@ -20,8 +20,6 @@ const (
UserIDClaim = "sub"
// LastLoginSuffix claim for the last login
LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
)
// ExtractClaims Extract function type
@@ -102,10 +100,6 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
if ok {
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
}
invitedBool, ok := claims[c.authAudience+Invited]
if ok {
jwtClaims.Invited = invitedBool.(bool)
}
return jwtClaims
}

View File

@@ -30,10 +30,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
if claims.LastLogin != (time.Time{}) {
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
}
if claims.Invited {
claimMaps[audience+Invited] = true
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed")
@@ -63,14 +59,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
AccountId: "testAcc",
LastLogin: lastLogin,
DomainCategory: "public",
Invited: true,
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc",
"https://login/nb_last_login": lastLogin.Format(layout),
"sub": "test",
"https://login/" + Invited: true,
},
},
testingFunc: require.EqualValues,

View File

@@ -1,204 +0,0 @@
package migration
import (
"database/sql"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"net"
"strings"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
// S is the type of the field to be migrated.
func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
var model T
if !db.Migrator().HasTable(&model) {
log.Debugf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
var item string
if err := db.Model(model).Select(oldColumnName).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debugf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
}
var js json.RawMessage
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
var rows []map[string]any
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
var field S
str, ok := row[oldColumnName].(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
reader := strings.NewReader(str)
if err := gob.NewDecoder(reader).Decode(&field); err != nil {
return fmt.Errorf("gob decode error: %w", err)
}
jsonValue, err := json.Marshal(field)
if err != nil {
return fmt.Errorf("re-encode to JSON: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
return fmt.Errorf("update row: %w", err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
}
return nil
}); err != nil {
return err
}
log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
return nil
}
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
var model T
if !db.Migrator().HasTable(&model) {
log.Printf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
var item sql.NullString
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Printf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
}
if item.Valid {
var js json.RawMessage
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item.String), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
}
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
var rows []map[string]any
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Printf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
var blobValue string
if columnValue := row[oldColumnName]; columnValue != nil {
value, ok := columnValue.(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
blobValue = value
}
columnIpValue := net.IP(blobValue)
if net.ParseIP(columnIpValue.String()) == nil {
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
columnIpValue = net.IPv6loopback
}
jsonValue, err := json.Marshal(columnIpValue)
if err != nil {
return fmt.Errorf("re-encode to JSON: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
return fmt.Errorf("update row: %w", err)
}
}
if indexName != "" {
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
return fmt.Errorf("drop index %s: %w", indexName, err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
}
return nil
}); err != nil {
return err
}
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
return nil
}

View File

@@ -1,161 +0,0 @@
package migration_test
import (
"encoding/gob"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
)
func setupDatabase(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
PrepareStmt: true,
})
require.NoError(t, err, "Failed to open database")
return db
}
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
server.Network
Net net.IPNet `gorm:"serializer:gob"`
}
type account struct {
server.Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
}
err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert Gob data")
var gobStr string
err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error
assert.NoError(t, err, "Failed to fetch Gob data")
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
require.NoError(t, err, "Failed to decode Gob data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail with Gob data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated")
}
func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail with JSON data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
}
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
type location struct {
nbpeer.Location
ConnectionIP net.IP
}
type peer struct {
nbpeer.Peer
Location location `gorm:"embedded;embeddedPrefix:location_"`
}
type account struct {
server.Account
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
err = db.Save(&account{
Account: server.Account{Id: "123"},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
).Error
require.NoError(t, err, "Failed to insert blob data")
var blobValue string
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
assert.NoError(t, err, "Failed to fetch blob data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP blob data")
var jsonStr string
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be migrated")
}
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.Account{
Id: "1234",
PeersG: []nbpeer.Peer{
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
var jsonStr string
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
}

View File

@@ -22,93 +22,80 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
// TODO implement me
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
@@ -193,7 +180,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
}
@@ -639,7 +626,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync)
}

View File

@@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil {
@@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -36,7 +36,7 @@ type NetworkMap struct {
type Network struct {
Identifier string `json:"id"`
Net net.IPNet `gorm:"serializer:json"`
Net net.IPNet `gorm:"serializer:gob"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.

View File

@@ -88,7 +88,21 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error {
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
@@ -142,7 +156,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -264,7 +278,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -348,7 +362,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountWriteLock(account.Id)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -360,14 +374,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(userID, account)
if err == nil && userdata != nil {
if err == nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
@@ -504,7 +518,25 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
return nil, nil, err
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, nil, err
}
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
@@ -571,7 +603,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountWriteLock(account.Id)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
@@ -728,7 +760,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
return err
}
unlock := am.Store.AcquireAccountWriteLock(account.Id)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -763,7 +795,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -833,10 +865,6 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
return
}
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})

View File

@@ -13,13 +13,13 @@ type Peer struct {
// ID is an internal ID of the peer
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
// IP address of the Peer
IP net.IP `gorm:"serializer:json"`
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
@@ -61,7 +61,7 @@ type PeerStatus struct { //nolint:revive
// Location is a geo location information of a Peer based on public connection IP
type Location struct {
ConnectionIP net.IP `gorm:"serializer:json"` // from grpc peer or reverse proxy headers depends on setup
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup
CountryCode string
CityName string
GeoNameID uint // city level geoname id

View File

@@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -7,7 +7,7 @@ import (
)
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
}
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
}
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
}
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -14,7 +14,7 @@ import (
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -115,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -194,7 +194,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if routeToSave == nil {
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -209,7 +209,7 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
keyDuration := DefaultSetupKeyDuration
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if keyToSave == nil {
@@ -327,7 +327,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)

View File

@@ -3,8 +3,6 @@ package server
import (
"errors"
"fmt"
"net"
"net/netip"
"path/filepath"
"runtime"
"strings"
@@ -20,7 +18,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@@ -43,8 +40,6 @@ type installation struct {
InstallationIDValue string
}
type migrationFunc func(*gorm.DB) error
// NewSqliteStore restores a store from the file located in the datadir
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
storeStr := "store.db?cache=shared"
@@ -55,9 +50,8 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
if err != nil {
return nil, err
@@ -70,16 +64,13 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable
if err := migrate(db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
}
err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
return nil, err
}
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
@@ -127,33 +118,17 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
return unlock
}
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Tracef("acquiring write lock for account %s", accountID)
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Tracef("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Tracef("released write lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
log.Tracef("acquiring read lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
unlock = func() {
mtx.RUnlock()
log.Tracef("released read lock for account %s in %v", accountID, time.Since(start))
log.Tracef("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
@@ -213,8 +188,7 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(account)
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
if result.Error != nil {
return result.Error
}
@@ -279,43 +253,36 @@ func (s *SqliteStore) GetInstallationID() string {
}
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(peerCopy)
var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil {
return result.Error
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peer.Status = &peerStatus
return nil
return s.db.Save(peer).Error
}
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
Updates(peerCopy)
var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID)
if result.Error != nil {
return result.Error
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
}
peer.Location = peerWithLocation.Location
return nil
return s.db.Save(peer).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
@@ -423,9 +390,8 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
}
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
var account Account
result := s.db.Debug().Model(&account).
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, "id = ?", accountID)
@@ -545,21 +511,6 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return s.GetAccount(peer.AccountID)
}
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
@@ -591,36 +542,3 @@ func (s *SqliteStore) Close() error {
func (s *SqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine
}
// migrate migrates the SQLite database to the latest schema
func migrate(db *gorm.DB) error {
migrations := getMigrations()
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
func getMigrations() []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
},
}
}

View File

@@ -2,23 +2,16 @@ package server
import (
"fmt"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -37,151 +30,6 @@ func TestSqlite_NewStore(t *testing.T) {
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
@@ -501,74 +349,6 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestMigrate(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
err := migrate(store.db)
require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
Network
Net net.IPNet `gorm:"serializer:gob"`
}
type location struct {
nbpeer.Location
ConnectionIP net.IP
}
type peer struct {
nbpeer.Peer
Location location `gorm:"embedded;embeddedPrefix:location_"`
}
type account struct {
Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
act := &account{
Network: &network{
Net: *ipnet,
},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
},
}
err = store.db.Save(act).Error
require.NoError(t, err, "Failed to insert Gob data")
type route struct {
route2.Route
Network netip.Prefix `gorm:"serializer:gob"`
PeerGroups []string `gorm:"serializer:gob"`
}
prefix := netip.MustParsePrefix("11.0.0.0/24")
rt := &route{
Network: prefix,
PeerGroups: []string{"group1", "group2"},
}
err = store.db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data")
err = migrate(store.db)
require.NoError(t, err, "Migration should not fail on gob populated db")
err = migrate(store.db)
require.NoError(t, err, "Migration should not fail on migrated db")
}
func newSqliteStore(t *testing.T) *SqliteStore {
t.Helper()

View File

@@ -19,7 +19,6 @@ type Store interface {
DeleteAccount(account *Account) error
GetAccountByUser(userID string) (*Account, error)
GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(peerKey string) (string, error)
GetAccountByPeerID(peerID string) (*Account, error)
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(domain string) (*Account, error)
@@ -30,10 +29,8 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string
SaveInstallationID(ID string) error
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
AcquireAccountWriteLock(accountID string) func()
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
AcquireAccountReadLock(accountID string) func()
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error

View File

@@ -210,7 +210,7 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -266,7 +266,7 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *User
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if am.idpManager == nil {
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
unlock := am.Store.AcquireAccountWriteLock(account.Id)
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
account, err = am.Store.GetAccount(account.Id)
@@ -400,7 +400,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -537,7 +537,7 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if am.idpManager == nil {
@@ -577,7 +577,7 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if tokenName == "" {
@@ -627,7 +627,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -677,7 +677,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -709,7 +709,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
@@ -747,7 +747,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if update == nil {
@@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(account.Users))
users := make(map[string]struct{}, len(account.Users))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users {
if user.Issued == UserIssuedIntegration {
@@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = true
users[user.Id] = struct{}{}
continue
}
usersFromIntegration = append(usersFromIntegration, info)
continue
}
if !user.IsServiceUser {
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
users[user.Id] = struct{}{}
}
}
queriedUsers, err = am.lookupCache(users, accountID)

View File

@@ -68,11 +68,11 @@ type Route struct {
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
Network netip.Prefix `gorm:"serializer:json"`
Network netip.Prefix `gorm:"serializer:gob"`
NetID string
Description string
Peer string
PeerGroups []string `gorm:"serializer:json"`
PeerGroups []string `gorm:"serializer:gob"`
NetworkType NetworkType
Masquerade bool
Metric int
@@ -107,12 +107,6 @@ func (r *Route) Copy() *Route {
// IsEqual compares one route with the other
func (r *Route) IsEqual(other *Route) bool {
if r == nil && other == nil {
return true
} else if r == nil || other == nil {
return false
}
return other.ID == r.ID &&
other.Description == r.Description &&
other.NetID == r.NetID &&

View File

@@ -2,6 +2,7 @@ package net
import (
"os"
"runtime"
"github.com/google/uuid"
)
@@ -23,5 +24,5 @@ func GenerateConnID() ConnectionID {
}
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
return os.Getenv(envDisableCustomRouting) == "true" || runtime.GOOS == "ios"
}