mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 07:54:49 -04:00
Compare commits
52 Commits
debug-user
...
test/add-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b2351193c | ||
|
|
63fd508556 | ||
|
|
760d61c7a3 | ||
|
|
93a0315120 | ||
|
|
676f201c83 | ||
|
|
b7173ab956 | ||
|
|
26b418c42f | ||
|
|
0267cd1ddd | ||
|
|
6b86350b9d | ||
|
|
102384bfbb | ||
|
|
0735340a0b | ||
|
|
51e4b9aba6 | ||
|
|
62f9c8ace9 | ||
|
|
c57869aa78 | ||
|
|
abf6a1e08e | ||
|
|
673f441d6e | ||
|
|
1a12100790 | ||
|
|
3e963ffeba | ||
|
|
86fa1eaa16 | ||
|
|
1046342e2c | ||
|
|
89729d85df | ||
|
|
2c5dff2f89 | ||
|
|
779643463d | ||
|
|
22ac5ea0e8 | ||
|
|
cf60191bb5 | ||
|
|
8bfab0d6dd | ||
|
|
921b5606ce | ||
|
|
84126f9425 | ||
|
|
489f13031b | ||
|
|
c5b065aec1 | ||
|
|
b09bc6534c | ||
|
|
34f1a366b3 | ||
|
|
483edfcdc6 | ||
|
|
ef2eace033 | ||
|
|
1bddfa5b7b | ||
|
|
6ea7c665dc | ||
|
|
4a3c782a31 | ||
|
|
9359fea507 | ||
|
|
fcd2c15a37 | ||
|
|
ebda0fc538 | ||
|
|
ac135ab11d | ||
|
|
25faf9283d | ||
|
|
59faaa99f6 | ||
|
|
9762b39f29 | ||
|
|
ffdd115ded | ||
|
|
055df9854c | ||
|
|
12f883badf | ||
|
|
2abb92b0d4 | ||
|
|
01c3719c5d | ||
|
|
7b64953eed | ||
|
|
9bc7d788f0 | ||
|
|
b5419ef11a |
@@ -235,13 +235,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Disable network map persistence after creating the debug bundle
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
|
||||
@@ -42,7 +42,6 @@ const (
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -188,7 +187,7 @@ func init() {
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -9,6 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||
@@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Create test writer
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup expectations - only highest priority handler should be called
|
||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
@@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
@@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
// Create and execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify expectations
|
||||
@@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
}).Once()
|
||||
|
||||
// Execute
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify all handlers were called in order
|
||||
@@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
handler3.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
||||
type mockResponseWriter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (m *mockResponseWriter) Close() error { return nil }
|
||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (m *mockResponseWriter) Hijack() {}
|
||||
|
||||
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup expectations
|
||||
for priority, handler := range handlers {
|
||||
@@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Highest priority handler (routeHandler) should be called
|
||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
@@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Test 2: Remove highest priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Now middle priority handler (matchHandler) should be called
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||
@@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
@@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||
|
||||
for _, m := range mocks {
|
||||
@@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
// Execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||
chain.ServeDNS(&test.MockResponseWriter{}, r)
|
||||
|
||||
// Verify each handler was called exactly as expected
|
||||
for _, h := range tt.addHandlers {
|
||||
@@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup handler expectations
|
||||
for pattern, handler := range handlers {
|
||||
@@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
handler := &nbdns.MockHandler{}
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// First verify no handler is called before adding any
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map // key: string (domain_class_type), value: []dns.RR
|
||||
}
|
||||
|
||||
func (d *localResolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *localResolver) stop() {
|
||||
}
|
||||
|
||||
// String returns a string representation of the local resolver
|
||||
func (d *localResolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *localResolver) id() handlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) > 0 {
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(r)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
err := w.WriteMsg(replyMessage)
|
||||
if err != nil {
|
||||
log.Debugf("got an error while writing the local resolver response, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
||||
if len(r.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(question.Name)
|
||||
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
|
||||
|
||||
value, found := d.records.Load(key)
|
||||
if !found {
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
r.Question[0].Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
records, ok := value.([]dns.RR)
|
||||
if !ok {
|
||||
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records.Store(key, records)
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// registerRecord stores a new record by appending it to any existing list
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
|
||||
rr, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
rr.Header().Rdlength = record.Len()
|
||||
header := rr.Header()
|
||||
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
|
||||
|
||||
// load any existing slice of records, then append
|
||||
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
|
||||
records := existing.([]dns.RR)
|
||||
records = append(records, rr)
|
||||
|
||||
// store updated slice
|
||||
d.records.Store(key, records)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// deleteRecord removes *all* records under the recordKey.
|
||||
func (d *localResolver) deleteRecord(recordKey string) {
|
||||
d.records.Delete(dns.Fqdn(recordKey))
|
||||
}
|
||||
|
||||
// buildRecordKey consistently generates a key: name_class_type
|
||||
func buildRecordKey(name string, class, qType uint16) string {
|
||||
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
|
||||
}
|
||||
|
||||
func (d *localResolver) probeAvailability() {}
|
||||
149
client/internal/dns/local/local.go
Normal file
149
client/internal/dns/local/local.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Resolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// String returns a string representation of the local resolver
|
||||
func (d *Resolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.records))
|
||||
}
|
||||
|
||||
func (d *Resolver) Stop() {}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
log.Debugf("received local resolver request with no question")
|
||||
return
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(question)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// TODO: return success if we have a different record type for the same name, relevant for search domains
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
log.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
|
||||
if !found {
|
||||
d.mu.RUnlock()
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
question.Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(question)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
recordsCopy := slices.Clone(records)
|
||||
d.mu.RUnlock()
|
||||
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(recordsCopy) > 1 {
|
||||
d.mu.Lock()
|
||||
records = d.records[question]
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records[question] = records
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
return recordsCopy
|
||||
}
|
||||
|
||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRecord stores a new record by appending it to any existing list
|
||||
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
return d.registerRecord(record)
|
||||
}
|
||||
|
||||
// registerRecord performs the registration with the lock already held
|
||||
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
rr, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
rr.Header().Rdlength = record.Len()
|
||||
header := rr.Header()
|
||||
q := dns.Question{
|
||||
Name: strings.ToLower(dns.Fqdn(header.Name)),
|
||||
Qtype: header.Rrtype,
|
||||
Qclass: header.Class,
|
||||
}
|
||||
|
||||
d.records[q] = append(d.records[q], rr)
|
||||
|
||||
return nil
|
||||
}
|
||||
472
client/internal/dns/local/local_test.go
Normal file
472
client/internal/dns/local/local_test.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "peerb.netbird.cloud.",
|
||||
Type: 5,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "www.netbird.io",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputRecord nbdns.SimpleRecord
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||
},
|
||||
{
|
||||
name: "Should Resolve CNAME Record",
|
||||
inputRecord: recordCNAME,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||
},
|
||||
{
|
||||
name: "Should Not Write When Not Found A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
_ = resolver.RegisterRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_Update_StaleRecord verifies that updating
|
||||
// a record correctly replaces the old one, preventing stale entries.
|
||||
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
recordName := "host.example.com."
|
||||
recordType := dns.TypeA
|
||||
recordClass := dns.ClassINET
|
||||
|
||||
record1 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
|
||||
}
|
||||
record2 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
|
||||
}
|
||||
|
||||
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
update1 := []nbdns.SimpleRecord{record1}
|
||||
update2 := []nbdns.SimpleRecord{record2}
|
||||
|
||||
// Apply first update
|
||||
resolver.Update(update1)
|
||||
|
||||
// Verify first update
|
||||
resolver.mu.RLock()
|
||||
rrSlice1, found1 := resolver.records[recordKey]
|
||||
resolver.mu.RUnlock()
|
||||
|
||||
require.True(t, found1, "Record key %s not found after first update", recordKey)
|
||||
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
|
||||
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||
|
||||
// Apply second update
|
||||
resolver.Update(update2)
|
||||
|
||||
// Verify second update
|
||||
resolver.mu.RLock()
|
||||
rrSlice2, found2 := resolver.records[recordKey]
|
||||
resolver.mu.RUnlock()
|
||||
|
||||
require.True(t, found2, "Record key %s not found after second update", recordKey)
|
||||
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
|
||||
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
|
||||
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
|
||||
}
|
||||
|
||||
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
|
||||
// with the same question are stored properly
|
||||
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
recordName := "multi.example.com."
|
||||
recordType := dns.TypeA
|
||||
|
||||
// Create two records with the same name and type but different IPs
|
||||
record1 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
|
||||
}
|
||||
record2 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||
}
|
||||
|
||||
update := []nbdns.SimpleRecord{record1, record2}
|
||||
|
||||
// Apply update with both records
|
||||
resolver.Update(update)
|
||||
|
||||
// Create question that matches both records
|
||||
question := dns.Question{
|
||||
Name: recordName,
|
||||
Qtype: recordType,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
|
||||
// Verify both records are stored
|
||||
resolver.mu.RLock()
|
||||
records, found := resolver.records[question]
|
||||
resolver.mu.RUnlock()
|
||||
|
||||
require.True(t, found, "Records for question %v not found", question)
|
||||
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
|
||||
|
||||
// Verify both record data values are present
|
||||
recordStrings := []string{records[0].String(), records[1].String()}
|
||||
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
|
||||
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
|
||||
}
|
||||
|
||||
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
|
||||
func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
recordName := "rotation.example.com."
|
||||
recordType := dns.TypeA
|
||||
|
||||
// Create three records with the same name and type but different IPs
|
||||
record1 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
|
||||
}
|
||||
record2 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
|
||||
}
|
||||
record3 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||
}
|
||||
|
||||
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||
|
||||
// Apply update with all three records
|
||||
resolver.Update(update)
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||
|
||||
// First lookup - should return the records in original order
|
||||
var responses [3]*dns.Msg
|
||||
|
||||
// Perform three lookups to verify rotation
|
||||
for i := 0; i < 3; i++ {
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responses[i] = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
}
|
||||
|
||||
// Verify all three responses contain answers
|
||||
for i, resp := range responses {
|
||||
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
|
||||
}
|
||||
|
||||
// Verify the first record in each response is different due to rotation
|
||||
firstRecordIPs := []string{
|
||||
responses[0].Answer[0].String(),
|
||||
responses[1].Answer[0].String(),
|
||||
responses[2].Answer[0].String(),
|
||||
}
|
||||
|
||||
// Each record should be different (rotated)
|
||||
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
|
||||
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
|
||||
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
|
||||
|
||||
// After three rotations, we should have cycled through all records
|
||||
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
|
||||
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
|
||||
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
|
||||
}
|
||||
|
||||
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
|
||||
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Create record with lowercase name
|
||||
lowerCaseRecord := nbdns.SimpleRecord{
|
||||
Name: "lower.example.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.10.10.10",
|
||||
}
|
||||
|
||||
// Create record with mixed case name
|
||||
mixedCaseRecord := nbdns.SimpleRecord{
|
||||
Name: "MiXeD.ExAmPlE.CoM.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "20.20.20.20",
|
||||
}
|
||||
|
||||
// Update resolver with the records
|
||||
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
queryName string
|
||||
expectedRData string
|
||||
shouldResolve bool
|
||||
}{
|
||||
{
|
||||
name: "Query lowercase with lowercase record",
|
||||
queryName: "lower.example.com.",
|
||||
expectedRData: "10.10.10.10",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query uppercase with lowercase record",
|
||||
queryName: "LOWER.EXAMPLE.COM.",
|
||||
expectedRData: "10.10.10.10",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query mixed case with lowercase record",
|
||||
queryName: "LoWeR.eXaMpLe.CoM.",
|
||||
expectedRData: "10.10.10.10",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query lowercase with mixed case record",
|
||||
queryName: "mixed.example.com.",
|
||||
expectedRData: "20.20.20.20",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query uppercase with mixed case record",
|
||||
queryName: "MIXED.EXAMPLE.COM.",
|
||||
expectedRData: "20.20.20.20",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query with different casing pattern",
|
||||
queryName: "mIxEd.ExaMpLe.cOm.",
|
||||
expectedRData: "20.20.20.20",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query non-existent domain",
|
||||
queryName: "nonexistent.example.com.",
|
||||
shouldResolve: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var responseMSG *dns.Msg
|
||||
|
||||
// Create DNS query with the test case name
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||
|
||||
// Create mock response writer to capture the response
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Perform DNS query
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
// Check if we expect a successful resolution
|
||||
if !tc.shouldResolve {
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||
// Expected no answer, test passes
|
||||
return
|
||||
}
|
||||
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||
}
|
||||
|
||||
// Verify we got a response
|
||||
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||
|
||||
// Verify the response contains the expected data
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
assert.Contains(t, answerString, tc.expectedRData,
|
||||
"Answer should contain the expected IP address %s, got: %s",
|
||||
tc.expectedRData, answerString)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
|
||||
// to checking for CNAME records when the requested record type isn't found
|
||||
func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Create a CNAME record (but no A record for this name)
|
||||
cnameRecord := nbdns.SimpleRecord{
|
||||
Name: "alias.example.com.",
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "target.example.com.",
|
||||
}
|
||||
|
||||
// Create an A record for the CNAME target
|
||||
targetRecord := nbdns.SimpleRecord{
|
||||
Name: "target.example.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.100.100",
|
||||
}
|
||||
|
||||
// Update resolver with both records
|
||||
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
queryName string
|
||||
queryType uint16
|
||||
expectedType string
|
||||
expectedRData string
|
||||
shouldResolve bool
|
||||
}{
|
||||
{
|
||||
name: "Directly query CNAME record",
|
||||
queryName: "alias.example.com.",
|
||||
queryType: dns.TypeCNAME,
|
||||
expectedType: "CNAME",
|
||||
expectedRData: "target.example.com.",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query A record but get CNAME fallback",
|
||||
queryName: "alias.example.com.",
|
||||
queryType: dns.TypeA,
|
||||
expectedType: "CNAME",
|
||||
expectedRData: "target.example.com.",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query AAAA record but get CNAME fallback",
|
||||
queryName: "alias.example.com.",
|
||||
queryType: dns.TypeAAAA,
|
||||
expectedType: "CNAME",
|
||||
expectedRData: "target.example.com.",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query direct A record",
|
||||
queryName: "target.example.com.",
|
||||
queryType: dns.TypeA,
|
||||
expectedType: "A",
|
||||
expectedRData: "192.168.100.100",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query non-existent name",
|
||||
queryName: "nonexistent.example.com.",
|
||||
queryType: dns.TypeA,
|
||||
shouldResolve: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var responseMSG *dns.Msg
|
||||
|
||||
// Create DNS query with the test case parameters
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
|
||||
|
||||
// Create mock response writer to capture the response
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Perform DNS query
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
// Check if we expect a successful resolution
|
||||
if !tc.shouldResolve {
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
|
||||
// Expected no resolution, test passes
|
||||
return
|
||||
}
|
||||
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||
}
|
||||
|
||||
// Verify we got a successful response
|
||||
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
|
||||
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||
|
||||
// Verify the response contains the expected data
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
assert.Contains(t, answerString, tc.expectedType,
|
||||
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
|
||||
assert.Contains(t, answerString, tc.expectedRData,
|
||||
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "peerb.netbird.cloud.",
|
||||
Type: 5,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "www.netbird.io",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputRecord nbdns.SimpleRecord
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||
},
|
||||
{
|
||||
name: "Should Resolve CNAME Record",
|
||||
inputRecord: recordCNAME,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||
},
|
||||
{
|
||||
name: "Should Not Write When Not Found A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
resolver := &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
}
|
||||
_, _ = resolver.registerRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (rw *mockResponseWriter) Close() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (rw *mockResponseWriter) Hijack() {}
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -46,8 +48,6 @@ type Server interface {
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type handlerID string
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
domain string
|
||||
groups []*nbdns.NameServerGroup
|
||||
@@ -61,7 +61,7 @@ type DefaultServer struct {
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *localResolver
|
||||
localResolver *local.Resolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
@@ -84,9 +84,9 @@ type DefaultServer struct {
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
stop()
|
||||
probeAvailability()
|
||||
id() handlerID
|
||||
Stop()
|
||||
ProbeAvailability()
|
||||
ID() types.HandlerID
|
||||
}
|
||||
|
||||
type handlerWrapper struct {
|
||||
@@ -95,7 +95,7 @@ type handlerWrapper struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(
|
||||
@@ -171,16 +171,14 @@ func newDefaultServer(
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: local.NewResolver(),
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
@@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
wg.Add(1)
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.probeAvailability()
|
||||
mux.ProbeAvailability()
|
||||
}(mux.handler)
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
@@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.updateMux(muxUpdates)
|
||||
|
||||
// register local records
|
||||
s.updateLocalResolver(localRecordsByDomain)
|
||||
s.localResolver.Update(localRecords)
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
@@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
customZones []nbdns.CustomZone,
|
||||
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
localRecords := make(map[string][]nbdns.SimpleRecord)
|
||||
var localRecords []nbdns.SimpleRecord
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
@@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
|
||||
// group all records under this domain
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
|
||||
localRecords[key] = append(localRecords[key], record)
|
||||
// zone records contain the fqdn, so we can just flatten them
|
||||
localRecords = append(localRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
handler.stop()
|
||||
handler.Stop()
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
@@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
existing.handler.stop()
|
||||
existing.handler.Stop()
|
||||
}
|
||||
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
@@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
containsRootUpdate = true
|
||||
}
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.id()] = update
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
// If there's no root update and we had a root handler, restore it
|
||||
@@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
|
||||
// remove old records that are no longer present
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for _, recs := range update {
|
||||
for _, rec := range recs {
|
||||
// convert the record to a dns.RR and register
|
||||
key, err := s.localResolver.registerRecord(rec)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v",
|
||||
rec.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
@@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := &localResolver{}
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalMap registrationMap
|
||||
initLocalRecords []nbdns.SimpleRecord
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalMap registrationMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
@@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
dummyHandler.id(): handlerWrapper{
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
name: "New Config Should Succeed",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
@@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalMap: make(registrationMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalMap: make(registrationMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
@@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
||||
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
for key := range testCase.expectedLocalMap {
|
||||
_, found := dnsServer.localResolver.registeredMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
||||
}
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &localResolver{},
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
@@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tc.query, dns.TypeA)
|
||||
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||
@@ -1037,9 +1047,9 @@ type mockHandler struct {
|
||||
}
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) stop() {}
|
||||
func (m *mockHandler) probeAvailability() {}
|
||||
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability() {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
@@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[handlerID]domain
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
description string
|
||||
}{
|
||||
{
|
||||
@@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
@@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for handlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||
for HandlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
@@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
|
||||
handlerChain: NewHandlerChain(),
|
||||
wgInterface: &mocWGIface{},
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
@@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
@@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
@@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
|
||||
26
client/internal/dns/test/mock.go
Normal file
26
client/internal/dns/test/mock.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type MockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (rw *MockResponseWriter) Close() error { return nil }
|
||||
func (rw *MockResponseWriter) TsigStatus() error { return nil }
|
||||
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (rw *MockResponseWriter) Hijack() {}
|
||||
3
client/internal/dns/types/types.go
Normal file
3
client/internal/dns/types/types.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package types
|
||||
|
||||
type HandlerID string
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
@@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (u *upstreamResolverBase) id() handlerID {
|
||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.Sort(servers)
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
hash.Write([]byte(strings.Join(servers, ",")))
|
||||
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) stop() {
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
}
|
||||
@@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||
)
|
||||
}
|
||||
|
||||
// probeAvailability tests all upstream servers simultaneously and
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) probeAvailability() {
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
@@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
@@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
resolver.failsTillDeact = 0
|
||||
resolver.reactivatePeriod = time.Microsecond * 100
|
||||
|
||||
responseWriter := &mockResponseWriter{
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||
}
|
||||
|
||||
|
||||
@@ -51,14 +51,16 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
}
|
||||
|
||||
if req.GetUploadURL() == "" {
|
||||
|
||||
return &proto.DebugBundleResponse{Path: path}, nil
|
||||
}
|
||||
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
|
||||
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
|
||||
}
|
||||
|
||||
log.Infof("debug bundle uploaded to %s with key %s", req.GetUploadURL(), key)
|
||||
|
||||
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -51,14 +51,17 @@ const (
|
||||
)
|
||||
|
||||
func main() {
|
||||
daemonAddr, showSettings, showNetworks, errorMsg, saveLogsInFile := parseFlags()
|
||||
daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags()
|
||||
|
||||
// Initialize file logging if needed.
|
||||
var logFile string
|
||||
if saveLogsInFile {
|
||||
if err := initLogFile(); err != nil {
|
||||
file, err := initLogFile()
|
||||
if err != nil {
|
||||
log.Errorf("error while initializing log: %v", err)
|
||||
return
|
||||
}
|
||||
logFile = file
|
||||
}
|
||||
|
||||
// Create the Fyne application.
|
||||
@@ -72,13 +75,13 @@ func main() {
|
||||
}
|
||||
|
||||
// Create the service client (this also builds the settings or networks UI if requested).
|
||||
client := newServiceClient(daemonAddr, a, showSettings, showNetworks)
|
||||
client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug)
|
||||
|
||||
// Watch for theme/settings changes to update the icon.
|
||||
go watchSettingsChanges(a, client)
|
||||
|
||||
// Run in window mode if any UI flag was set.
|
||||
if showSettings || showNetworks {
|
||||
if showSettings || showNetworks || showDebug {
|
||||
a.Run()
|
||||
return
|
||||
}
|
||||
@@ -99,7 +102,7 @@ func main() {
|
||||
}
|
||||
|
||||
// parseFlags reads and returns all needed command-line flags.
|
||||
func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg string, saveLogsInFile bool) {
|
||||
func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) {
|
||||
defaultDaemonAddr := "unix:///var/run/netbird.sock"
|
||||
if runtime.GOOS == "windows" {
|
||||
defaultDaemonAddr = "tcp://127.0.0.1:41731"
|
||||
@@ -107,25 +110,17 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg
|
||||
flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
|
||||
flag.BoolVar(&showSettings, "settings", false, "run settings window")
|
||||
flag.BoolVar(&showNetworks, "networks", false, "run networks window")
|
||||
flag.BoolVar(&showDebug, "debug", false, "run debug window")
|
||||
flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window")
|
||||
|
||||
tmpDir := "/tmp"
|
||||
if runtime.GOOS == "windows" {
|
||||
tmpDir = os.TempDir()
|
||||
}
|
||||
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
|
||||
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
||||
flag.Parse()
|
||||
return
|
||||
}
|
||||
|
||||
// initLogFile initializes logging into a file.
|
||||
func initLogFile() error {
|
||||
tmpDir := "/tmp"
|
||||
if runtime.GOOS == "windows" {
|
||||
tmpDir = os.TempDir()
|
||||
}
|
||||
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
|
||||
return util.InitLog("trace", logFile)
|
||||
func initLogFile() (string, error) {
|
||||
logFile := path.Join(os.TempDir(), fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
|
||||
return logFile, util.InitLog("trace", logFile)
|
||||
}
|
||||
|
||||
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
|
||||
@@ -168,9 +163,10 @@ var iconConnectingMacOS []byte
|
||||
var iconErrorMacOS []byte
|
||||
|
||||
type serviceClient struct {
|
||||
ctx context.Context
|
||||
addr string
|
||||
conn proto.DaemonServiceClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
addr string
|
||||
conn proto.DaemonServiceClient
|
||||
|
||||
icAbout []byte
|
||||
icConnected []byte
|
||||
@@ -231,13 +227,14 @@ type serviceClient struct {
|
||||
daemonVersion string
|
||||
updateIndicationLock sync.Mutex
|
||||
isUpdateIconActive bool
|
||||
showRoutes bool
|
||||
wRoutes fyne.Window
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
|
||||
eventManager *event.Manager
|
||||
|
||||
exitNodeMu sync.Mutex
|
||||
mExitNodeItems []menuHandler
|
||||
logFile string
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -248,25 +245,30 @@ type menuHandler struct {
|
||||
// newServiceClient instance constructor
|
||||
//
|
||||
// This constructor also builds the UI elements for the settings window.
|
||||
func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient {
|
||||
func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &serviceClient{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
addr: addr,
|
||||
app: a,
|
||||
logFile: logFile,
|
||||
sendNotification: false,
|
||||
|
||||
showAdvancedSettings: showSettings,
|
||||
showRoutes: showRoutes,
|
||||
showNetworks: showNetworks,
|
||||
update: version.NewUpdate(),
|
||||
}
|
||||
|
||||
s.setNewIcons()
|
||||
|
||||
if showSettings {
|
||||
switch {
|
||||
case showSettings:
|
||||
s.showSettingsUI()
|
||||
return s
|
||||
} else if showRoutes {
|
||||
case showNetworks:
|
||||
s.showNetworksUI()
|
||||
case showDebug:
|
||||
s.showDebugUI()
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -313,6 +315,8 @@ func (s *serviceClient) updateIcon() {
|
||||
func (s *serviceClient) showSettingsUI() {
|
||||
// add settings window UI elements.
|
||||
s.wSettings = s.app.NewWindow("NetBird Settings")
|
||||
s.wSettings.SetOnClosed(s.cancel)
|
||||
|
||||
s.iMngURL = widget.NewEntry()
|
||||
s.iAdminURL = widget.NewEntry()
|
||||
s.iConfigFile = widget.NewEntry()
|
||||
@@ -743,11 +747,10 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.runSelfCommand("settings", "true")
|
||||
}()
|
||||
case <-s.mCreateDebugBundle.ClickedCh:
|
||||
s.mCreateDebugBundle.Disable()
|
||||
go func() {
|
||||
if err := s.createAndOpenDebugBundle(); err != nil {
|
||||
log.Errorf("Failed to create debug bundle: %v", err)
|
||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle"))
|
||||
}
|
||||
defer s.mCreateDebugBundle.Enable()
|
||||
s.runSelfCommand("debug", "true")
|
||||
}()
|
||||
case <-s.mQuit.ClickedCh:
|
||||
systray.Quit()
|
||||
@@ -789,7 +792,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
func (s *serviceClient) runSelfCommand(command, arg string) {
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("show %s failed with error: %v", command, err)
|
||||
log.Errorf("Error getting executable path: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -798,14 +801,48 @@ func (s *serviceClient) runSelfCommand(command, arg string) {
|
||||
fmt.Sprintf("--daemon-addr=%s", s.addr),
|
||||
)
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
|
||||
log.Errorf("start %s UI: %v, %s", command, err, string(out))
|
||||
if out := s.attachOutput(cmd); out != nil {
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Errorf("Error closing log file %s: %v", s.logFile, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
log.Printf("Running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, s.addr)
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
if err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
log.Printf("Command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
|
||||
} else {
|
||||
log.Printf("Failed to start/run command '%s %s': %v", command, arg, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(out) != 0 {
|
||||
log.Infof("command %s executed: %s", command, string(out))
|
||||
|
||||
log.Printf("Command '%s %s' completed successfully.", command, arg)
|
||||
}
|
||||
|
||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||
if s.logFile == "" {
|
||||
// attach child's streams to parent's streams
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
out, err := os.OpenFile(s.logFile, os.O_WRONLY|os.O_APPEND, 0)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open log file %s: %v", s.logFile, err)
|
||||
return nil
|
||||
}
|
||||
cmd.Stdout = out
|
||||
cmd.Stderr = out
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizedVersion(version string) string {
|
||||
@@ -818,9 +855,7 @@ func normalizedVersion(version string) string {
|
||||
|
||||
// onTrayExit is called when the tray icon is closed.
|
||||
func (s *serviceClient) onTrayExit() {
|
||||
for _, item := range s.mExitNodeItems {
|
||||
item.cancel()
|
||||
}
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// getSrvClient connection to the service.
|
||||
@@ -829,7 +864,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
||||
return s.conn, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
ctx, cancel := context.WithTimeout(s.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
|
||||
@@ -3,48 +3,721 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/dialog"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
func (s *serviceClient) createAndOpenDebugBundle() error {
|
||||
// Initial state for the debug collection
|
||||
type debugInitialState struct {
|
||||
wasDown bool
|
||||
logLevel proto.LogLevel
|
||||
isLevelTrace bool
|
||||
}
|
||||
|
||||
// Debug collection parameters
|
||||
type debugCollectionParams struct {
|
||||
duration time.Duration
|
||||
anonymize bool
|
||||
systemInfo bool
|
||||
upload bool
|
||||
uploadURL string
|
||||
enablePersistence bool
|
||||
}
|
||||
|
||||
// UI components for progress tracking
|
||||
type progressUI struct {
|
||||
statusLabel *widget.Label
|
||||
progressBar *widget.ProgressBar
|
||||
uiControls []fyne.Disableable
|
||||
window fyne.Window
|
||||
}
|
||||
|
||||
func (s *serviceClient) showDebugUI() {
|
||||
w := s.app.NewWindow("NetBird Debug")
|
||||
w.SetOnClosed(s.cancel)
|
||||
|
||||
w.Resize(fyne.NewSize(600, 500))
|
||||
w.SetFixedSize(true)
|
||||
|
||||
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
|
||||
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
|
||||
systemInfoCheck.SetChecked(true)
|
||||
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
|
||||
uploadCheck.SetChecked(true)
|
||||
|
||||
uploadURLLabel := widget.NewLabel("Debug upload URL:")
|
||||
uploadURL := widget.NewEntry()
|
||||
uploadURL.SetText(uptypes.DefaultBundleURL)
|
||||
uploadURL.SetPlaceHolder("Enter upload URL")
|
||||
|
||||
uploadURLContainer := container.NewVBox(
|
||||
uploadURLLabel,
|
||||
uploadURL,
|
||||
)
|
||||
|
||||
uploadCheck.OnChanged = func(checked bool) {
|
||||
if checked {
|
||||
uploadURLContainer.Show()
|
||||
} else {
|
||||
uploadURLContainer.Hide()
|
||||
}
|
||||
}
|
||||
|
||||
debugModeContainer := container.NewHBox()
|
||||
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
|
||||
runForDurationCheck.SetChecked(true)
|
||||
|
||||
forLabel := widget.NewLabel("for")
|
||||
|
||||
durationInput := widget.NewEntry()
|
||||
durationInput.SetText("1")
|
||||
minutesLabel := widget.NewLabel("minute")
|
||||
durationInput.Validator = func(s string) error {
|
||||
return validateMinute(s, minutesLabel)
|
||||
}
|
||||
|
||||
noteLabel := widget.NewLabel("Note: NetBird will be brought up and down during collection")
|
||||
|
||||
runForDurationCheck.OnChanged = func(checked bool) {
|
||||
if checked {
|
||||
forLabel.Show()
|
||||
durationInput.Show()
|
||||
minutesLabel.Show()
|
||||
noteLabel.Show()
|
||||
} else {
|
||||
forLabel.Hide()
|
||||
durationInput.Hide()
|
||||
minutesLabel.Hide()
|
||||
noteLabel.Hide()
|
||||
}
|
||||
}
|
||||
|
||||
debugModeContainer.Add(runForDurationCheck)
|
||||
debugModeContainer.Add(forLabel)
|
||||
debugModeContainer.Add(durationInput)
|
||||
debugModeContainer.Add(minutesLabel)
|
||||
|
||||
statusLabel := widget.NewLabel("")
|
||||
statusLabel.Hide()
|
||||
|
||||
progressBar := widget.NewProgressBar()
|
||||
progressBar.Hide()
|
||||
|
||||
createButton := widget.NewButton("Create Debug Bundle", nil)
|
||||
|
||||
// UI controls that should be disabled during debug collection
|
||||
uiControls := []fyne.Disableable{
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
uploadCheck,
|
||||
uploadURL,
|
||||
runForDurationCheck,
|
||||
durationInput,
|
||||
createButton,
|
||||
}
|
||||
|
||||
createButton.OnTapped = s.getCreateHandler(
|
||||
statusLabel,
|
||||
progressBar,
|
||||
uploadCheck,
|
||||
uploadURL,
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
runForDurationCheck,
|
||||
durationInput,
|
||||
uiControls,
|
||||
w,
|
||||
)
|
||||
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
||||
widget.NewLabel(""),
|
||||
anonymizeCheck,
|
||||
systemInfoCheck,
|
||||
uploadCheck,
|
||||
uploadURLContainer,
|
||||
widget.NewLabel(""),
|
||||
debugModeContainer,
|
||||
noteLabel,
|
||||
widget.NewLabel(""),
|
||||
statusLabel,
|
||||
progressBar,
|
||||
createButton,
|
||||
)
|
||||
|
||||
paddedContent := container.NewPadded(content)
|
||||
w.SetContent(paddedContent)
|
||||
|
||||
w.Show()
|
||||
}
|
||||
|
||||
func validateMinute(s string, minutesLabel *widget.Label) error {
|
||||
if val, err := strconv.Atoi(s); err != nil || val < 1 {
|
||||
return fmt.Errorf("must be a number ≥ 1")
|
||||
}
|
||||
if s == "1" {
|
||||
minutesLabel.SetText("minute")
|
||||
} else {
|
||||
minutesLabel.SetText("minutes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableUIControls disables the provided UI controls
|
||||
func disableUIControls(controls []fyne.Disableable) {
|
||||
for _, control := range controls {
|
||||
control.Disable()
|
||||
}
|
||||
}
|
||||
|
||||
// enableUIControls enables the provided UI controls
|
||||
func enableUIControls(controls []fyne.Disableable) {
|
||||
for _, control := range controls {
|
||||
control.Enable()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) getCreateHandler(
|
||||
statusLabel *widget.Label,
|
||||
progressBar *widget.ProgressBar,
|
||||
uploadCheck *widget.Check,
|
||||
uploadURL *widget.Entry,
|
||||
anonymizeCheck *widget.Check,
|
||||
systemInfoCheck *widget.Check,
|
||||
runForDurationCheck *widget.Check,
|
||||
duration *widget.Entry,
|
||||
uiControls []fyne.Disableable,
|
||||
w fyne.Window,
|
||||
) func() {
|
||||
return func() {
|
||||
disableUIControls(uiControls)
|
||||
statusLabel.Show()
|
||||
|
||||
var url string
|
||||
if uploadCheck.Checked {
|
||||
url = uploadURL.Text
|
||||
if url == "" {
|
||||
statusLabel.SetText("Error: Upload URL is required when upload is enabled")
|
||||
enableUIControls(uiControls)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
params := &debugCollectionParams{
|
||||
anonymize: anonymizeCheck.Checked,
|
||||
systemInfo: systemInfoCheck.Checked,
|
||||
upload: uploadCheck.Checked,
|
||||
uploadURL: url,
|
||||
enablePersistence: true,
|
||||
}
|
||||
|
||||
runForDuration := runForDurationCheck.Checked
|
||||
if runForDuration {
|
||||
minutes, err := time.ParseDuration(duration.Text + "m")
|
||||
if err != nil {
|
||||
statusLabel.SetText(fmt.Sprintf("Error: Invalid duration: %v", err))
|
||||
enableUIControls(uiControls)
|
||||
return
|
||||
}
|
||||
params.duration = minutes
|
||||
|
||||
statusLabel.SetText(fmt.Sprintf("Running in debug mode for %d minutes...", int(minutes.Minutes())))
|
||||
progressBar.Show()
|
||||
progressBar.SetValue(0)
|
||||
|
||||
go s.handleRunForDuration(
|
||||
statusLabel,
|
||||
progressBar,
|
||||
uiControls,
|
||||
w,
|
||||
params,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
statusLabel.SetText("Creating debug bundle...")
|
||||
go s.handleDebugCreation(
|
||||
anonymizeCheck.Checked,
|
||||
systemInfoCheck.Checked,
|
||||
uploadCheck.Checked,
|
||||
url,
|
||||
statusLabel,
|
||||
uiControls,
|
||||
w,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleRunForDuration(
|
||||
statusLabel *widget.Label,
|
||||
progressBar *widget.ProgressBar,
|
||||
uiControls []fyne.Disableable,
|
||||
w fyne.Window,
|
||||
params *debugCollectionParams,
|
||||
) {
|
||||
progressUI := &progressUI{
|
||||
statusLabel: statusLabel,
|
||||
progressBar: progressBar,
|
||||
uiControls: uiControls,
|
||||
window: w,
|
||||
}
|
||||
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get client: %v", err)
|
||||
handleError(progressUI, fmt.Sprintf("Failed to get client for debug: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
initialState, err := s.getInitialState(conn)
|
||||
if err != nil {
|
||||
handleError(progressUI, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
|
||||
if err != nil {
|
||||
handleError(progressUI, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
|
||||
handleError(progressUI, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.restoreServiceState(conn, initialState)
|
||||
|
||||
progressUI.statusLabel.SetText("Bundle created successfully")
|
||||
}
|
||||
|
||||
// Get initial state of the service
|
||||
func (s *serviceClient) getInitialState(conn proto.DaemonServiceClient) (*debugInitialState, error) {
|
||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(" get status: %v", err)
|
||||
}
|
||||
|
||||
logLevelResp, err := conn.GetLogLevel(s.ctx, &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get log level: %v", err)
|
||||
}
|
||||
|
||||
wasDown := statusResp.Status != string(internal.StatusConnected) &&
|
||||
statusResp.Status != string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel := logLevelResp.GetLevel()
|
||||
initialLevelTrace := initialLogLevel >= proto.LogLevel_TRACE
|
||||
|
||||
return &debugInitialState{
|
||||
wasDown: wasDown,
|
||||
logLevel: initialLogLevel,
|
||||
isLevelTrace: initialLevelTrace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handle progress tracking during collection
|
||||
func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time.Duration, progress *progressUI) {
|
||||
progress.progressBar.Show()
|
||||
progress.progressBar.SetValue(0)
|
||||
|
||||
startTime := time.Now()
|
||||
endTime := startTime.Add(duration)
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
remaining := time.Until(endTime)
|
||||
if remaining <= 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
progressVal := float64(elapsed) / float64(duration)
|
||||
if progressVal > 1.0 {
|
||||
progressVal = 1.0
|
||||
}
|
||||
|
||||
progress.progressBar.SetValue(progressVal)
|
||||
progress.statusLabel.SetText(fmt.Sprintf("Running with trace logs... %s remaining", formatDuration(remaining)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (s *serviceClient) configureServiceForDebug(
|
||||
conn proto.DaemonServiceClient,
|
||||
state *debugInitialState,
|
||||
enablePersistence bool,
|
||||
) error {
|
||||
if state.wasDown {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service up: %v", err)
|
||||
}
|
||||
log.Info("Service brought up for debug")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
if !state.isLevelTrace {
|
||||
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
|
||||
return fmt.Errorf("set log level to TRACE: %v", err)
|
||||
}
|
||||
log.Info("Log level set to TRACE for debug")
|
||||
}
|
||||
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service down: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
||||
if enablePersistence {
|
||||
if _, err := conn.SetNetworkMapPersistence(s.ctx, &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enable network map persistence: %v", err)
|
||||
}
|
||||
log.Info("Network map persistence enabled for debug")
|
||||
}
|
||||
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service back up: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) collectDebugData(
|
||||
conn proto.DaemonServiceClient,
|
||||
state *debugInitialState,
|
||||
params *debugCollectionParams,
|
||||
progress *progressUI,
|
||||
) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||
|
||||
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get post-up status: %v", err)
|
||||
}
|
||||
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
|
||||
|
||||
wg.Wait()
|
||||
progress.progressBar.Hide()
|
||||
progress.statusLabel.SetText("Collecting debug data...")
|
||||
|
||||
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get pre-down status: %v", err)
|
||||
}
|
||||
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
time.Now().Format(time.RFC3339), params.duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
|
||||
|
||||
return statusOutput, nil
|
||||
}
|
||||
|
||||
// Create the debug bundle with collected data
|
||||
func (s *serviceClient) createDebugBundleFromCollection(
|
||||
conn proto.DaemonServiceClient,
|
||||
params *debugCollectionParams,
|
||||
statusOutput string,
|
||||
progress *progressUI,
|
||||
) error {
|
||||
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
|
||||
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: params.anonymize,
|
||||
Status: statusOutput,
|
||||
SystemInfo: params.systemInfo,
|
||||
}
|
||||
|
||||
if params.upload {
|
||||
request.UploadURL = params.uploadURL
|
||||
}
|
||||
|
||||
resp, err := conn.DebugBundle(s.ctx, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create debug bundle: %v", err)
|
||||
}
|
||||
|
||||
// Show appropriate dialog based on upload status
|
||||
localPath := resp.GetPath()
|
||||
uploadFailureReason := resp.GetUploadFailureReason()
|
||||
uploadedKey := resp.GetUploadedKey()
|
||||
|
||||
if params.upload {
|
||||
if uploadFailureReason != "" {
|
||||
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
|
||||
} else {
|
||||
showUploadSuccessDialog(progress.window, localPath, uploadedKey)
|
||||
}
|
||||
} else {
|
||||
showBundleCreatedDialog(progress.window, localPath)
|
||||
}
|
||||
|
||||
enableUIControls(progress.uiControls)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore service to original state
|
||||
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
|
||||
if state.wasDown {
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("Failed to restore down state: %v", err)
|
||||
} else {
|
||||
log.Info("Service state restored to down")
|
||||
}
|
||||
}
|
||||
|
||||
if !state.isLevelTrace {
|
||||
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
|
||||
log.Errorf("Failed to restore log level: %v", err)
|
||||
} else {
|
||||
log.Info("Log level restored to original setting")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle errors during debug collection
|
||||
func handleError(progress *progressUI, errMsg string) {
|
||||
log.Errorf("%s", errMsg)
|
||||
progress.statusLabel.SetText(errMsg)
|
||||
progress.progressBar.Hide()
|
||||
enableUIControls(progress.uiControls)
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleDebugCreation(
|
||||
anonymize bool,
|
||||
systemInfo bool,
|
||||
upload bool,
|
||||
uploadURL string,
|
||||
statusLabel *widget.Label,
|
||||
uiControls []fyne.Disableable,
|
||||
w fyne.Window,
|
||||
) {
|
||||
log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
|
||||
anonymize, systemInfo, upload)
|
||||
|
||||
resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create debug bundle: %v", err)
|
||||
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
|
||||
enableUIControls(uiControls)
|
||||
return
|
||||
}
|
||||
|
||||
localPath := resp.GetPath()
|
||||
uploadFailureReason := resp.GetUploadFailureReason()
|
||||
uploadedKey := resp.GetUploadedKey()
|
||||
|
||||
if upload {
|
||||
if uploadFailureReason != "" {
|
||||
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
||||
} else {
|
||||
showUploadSuccessDialog(w, localPath, uploadedKey)
|
||||
}
|
||||
} else {
|
||||
showBundleCreatedDialog(w, localPath)
|
||||
}
|
||||
|
||||
enableUIControls(uiControls)
|
||||
statusLabel.SetText("Bundle created successfully")
|
||||
}
|
||||
|
||||
func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploadURL string) (*proto.DebugBundleResponse, error) {
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
|
||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get status: %v", err)
|
||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||
}
|
||||
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil)
|
||||
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
|
||||
resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{
|
||||
Anonymize: true,
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymize,
|
||||
Status: statusOutput,
|
||||
SystemInfo: true,
|
||||
})
|
||||
SystemInfo: systemInfo,
|
||||
}
|
||||
|
||||
if uploadURL != "" {
|
||||
request.UploadURL = uploadURL
|
||||
}
|
||||
|
||||
resp, err := conn.DebugBundle(s.ctx, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create debug bundle: %v", err)
|
||||
return nil, fmt.Errorf("failed to create debug bundle via daemon: %v", err)
|
||||
}
|
||||
|
||||
bundleDir := filepath.Dir(resp.GetPath())
|
||||
if err := open.Start(bundleDir); err != nil {
|
||||
return fmt.Errorf("failed to open debug bundle directory: %v", err)
|
||||
}
|
||||
|
||||
s.app.SendNotification(fyne.NewNotification(
|
||||
"Debug Bundle",
|
||||
fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()),
|
||||
))
|
||||
|
||||
return nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// formatDuration formats a duration in HH:MM:SS format
|
||||
func formatDuration(d time.Duration) string {
|
||||
d = d.Round(time.Second)
|
||||
h := d / time.Hour
|
||||
d %= time.Hour
|
||||
m := d / time.Minute
|
||||
d %= time.Minute
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
|
||||
// createButtonWithAction creates a button with the given label and action
|
||||
func createButtonWithAction(label string, action func()) *widget.Button {
|
||||
button := widget.NewButton(label, action)
|
||||
return button
|
||||
}
|
||||
|
||||
// showUploadFailedDialog displays a dialog when upload fails
|
||||
func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel(fmt.Sprintf("Bundle upload failed:\n%s\n\n"+
|
||||
"A local copy was saved at:\n%s", failureReason, localPath)),
|
||||
)
|
||||
|
||||
customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, w)
|
||||
|
||||
buttonBox := container.NewHBox(
|
||||
createButtonWithAction("Open file", func() {
|
||||
log.Infof("Attempting to open local file: %s", localPath)
|
||||
if openErr := open.Start(localPath); openErr != nil {
|
||||
log.Errorf("Failed to open local file '%s': %v", localPath, openErr)
|
||||
dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w)
|
||||
}
|
||||
}),
|
||||
createButtonWithAction("Open folder", func() {
|
||||
folderPath := filepath.Dir(localPath)
|
||||
log.Infof("Attempting to open local folder: %s", folderPath)
|
||||
if openErr := open.Start(folderPath); openErr != nil {
|
||||
log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr)
|
||||
dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w)
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
content.Add(buttonBox)
|
||||
customDialog.Show()
|
||||
}
|
||||
|
||||
// showUploadSuccessDialog displays a dialog when upload succeeds
|
||||
func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
|
||||
log.Infof("Upload key: %s", uploadedKey)
|
||||
keyEntry := widget.NewEntry()
|
||||
keyEntry.SetText(uploadedKey)
|
||||
keyEntry.Disable()
|
||||
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel("Bundle uploaded successfully!"),
|
||||
widget.NewLabel(""),
|
||||
widget.NewLabel("Upload key:"),
|
||||
keyEntry,
|
||||
widget.NewLabel(""),
|
||||
widget.NewLabel(fmt.Sprintf("Local copy saved at:\n%s", localPath)),
|
||||
)
|
||||
|
||||
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
|
||||
|
||||
copyBtn := createButtonWithAction("Copy key", func() {
|
||||
w.Clipboard().SetContent(uploadedKey)
|
||||
log.Info("Upload key copied to clipboard")
|
||||
})
|
||||
|
||||
buttonBox := createButtonBox(localPath, w, copyBtn)
|
||||
content.Add(buttonBox)
|
||||
customDialog.Show()
|
||||
}
|
||||
|
||||
// showBundleCreatedDialog displays a dialog when bundle is created without upload
|
||||
func showBundleCreatedDialog(w fyne.Window, localPath string) {
|
||||
content := container.NewVBox(
|
||||
widget.NewLabel(fmt.Sprintf("Bundle created locally at:\n%s\n\n"+
|
||||
"Administrator privileges may be required to access the file.", localPath)),
|
||||
)
|
||||
|
||||
customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, w)
|
||||
|
||||
buttonBox := createButtonBox(localPath, w, nil)
|
||||
content.Add(buttonBox)
|
||||
customDialog.Show()
|
||||
}
|
||||
|
||||
func createButtonBox(localPath string, w fyne.Window, elems ...fyne.Widget) *fyne.Container {
|
||||
box := container.NewHBox()
|
||||
for _, elem := range elems {
|
||||
box.Add(elem)
|
||||
}
|
||||
|
||||
fileBtn := createButtonWithAction("Open file", func() {
|
||||
log.Infof("Attempting to open local file: %s", localPath)
|
||||
if openErr := open.Start(localPath); openErr != nil {
|
||||
log.Errorf("Failed to open local file '%s': %v", localPath, openErr)
|
||||
dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w)
|
||||
}
|
||||
})
|
||||
|
||||
folderBtn := createButtonWithAction("Open folder", func() {
|
||||
folderPath := filepath.Dir(localPath)
|
||||
log.Infof("Attempting to open local folder: %s", folderPath)
|
||||
if openErr := open.Start(folderPath); openErr != nil {
|
||||
log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr)
|
||||
dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w)
|
||||
}
|
||||
})
|
||||
|
||||
box.Add(fileBtn)
|
||||
box.Add(folderBtn)
|
||||
|
||||
return box
|
||||
}
|
||||
|
||||
@@ -34,7 +34,8 @@ const (
|
||||
type filter string
|
||||
|
||||
func (s *serviceClient) showNetworksUI() {
|
||||
s.wRoutes = s.app.NewWindow("Networks")
|
||||
s.wNetworks = s.app.NewWindow("Networks")
|
||||
s.wNetworks.SetOnClosed(s.cancel)
|
||||
|
||||
allGrid := container.New(layout.NewGridLayout(3))
|
||||
go s.updateNetworks(allGrid, allNetworks)
|
||||
@@ -78,8 +79,8 @@ func (s *serviceClient) showNetworksUI() {
|
||||
|
||||
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
|
||||
|
||||
s.wRoutes.SetContent(content)
|
||||
s.wRoutes.Show()
|
||||
s.wNetworks.SetContent(content)
|
||||
s.wNetworks.Show()
|
||||
|
||||
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}
|
||||
@@ -148,7 +149,7 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) {
|
||||
grid.Add(resolvedIPsSelector)
|
||||
}
|
||||
|
||||
s.wRoutes.Content().Refresh()
|
||||
s.wNetworks.Content().Refresh()
|
||||
grid.Refresh()
|
||||
}
|
||||
|
||||
@@ -305,7 +306,7 @@ func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.Se
|
||||
func (s *serviceClient) showError(err error) {
|
||||
wrappedMessage := wrapText(err.Error(), 50)
|
||||
|
||||
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
|
||||
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wNetworks)
|
||||
}
|
||||
|
||||
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
|
||||
@@ -316,14 +317,15 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container
|
||||
}
|
||||
}()
|
||||
|
||||
s.wRoutes.SetOnClosed(func() {
|
||||
s.wNetworks.SetOnClosed(func() {
|
||||
ticker.Stop()
|
||||
s.cancel()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
|
||||
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
|
||||
s.wRoutes.Content().Refresh()
|
||||
s.wNetworks.Content().Refresh()
|
||||
s.updateNetworks(grid, f)
|
||||
}
|
||||
|
||||
@@ -373,7 +375,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
|
||||
node.Selected,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{
|
||||
MenuItem: menuItem,
|
||||
cancel: cancel,
|
||||
|
||||
@@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
|
||||
func (s SimpleRecord) Len() uint16 {
|
||||
emptyString := s.RData == ""
|
||||
switch s.Type {
|
||||
case 1:
|
||||
case int(dns.TypeA):
|
||||
if emptyString {
|
||||
return 0
|
||||
}
|
||||
return net.IPv4len
|
||||
case 5:
|
||||
case int(dns.TypeCNAME):
|
||||
if emptyString || s.RData == "." {
|
||||
return 1
|
||||
}
|
||||
return uint16(len(s.RData) + 1)
|
||||
case 28:
|
||||
case int(dns.TypeAAAA):
|
||||
if emptyString {
|
||||
return 0
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -106,6 +106,7 @@ require (
|
||||
golang.org/x/oauth2 v0.24.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/term v0.31.0
|
||||
golang.org/x/time v0.5.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@@ -240,7 +241,6 @@ require (
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Account](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -36,7 +38,9 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Account](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NameserverGroup](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -88,7 +98,9 @@ func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSSettings](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -104,7 +116,9 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSSettings](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Event](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Country](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -30,7 +32,9 @@ func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode stri
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.City](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Group](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Group](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Group](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Group](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Network](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Network](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Network](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Network](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -102,7 +112,9 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NetworkResource](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -114,7 +126,9 @@ func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkResource](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -130,7 +144,9 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkResource](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -146,7 +162,9 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkResource](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -158,7 +176,9 @@ func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID stri
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -184,7 +204,9 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NetworkRouter](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -196,7 +218,9 @@ func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkRouter](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -212,7 +236,9 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkRouter](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -228,7 +254,9 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkRouter](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -240,7 +268,9 @@ func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Peer](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Peer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Peer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -60,7 +66,9 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -72,7 +80,9 @@ func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]ap
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Peer](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Policy](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Policy](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Policy](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Policy](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.PostureCheck](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PostureCheck](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PostureCheck](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PostureCheck](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) er
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Route](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Route](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Route](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Route](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.SetupKey](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKe
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupKey](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJ
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupKeyClear](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupKey](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.PersonalAccessToken](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -32,7 +34,9 @@ func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.Perso
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PersonalAccessToken](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -48,7 +52,9 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -60,7 +66,9 @@ func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.User](resp)
|
||||
return ret, err
|
||||
}
|
||||
@@ -36,7 +38,9 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.User](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -52,7 +56,9 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.User](resp)
|
||||
return &ret, err
|
||||
}
|
||||
@@ -64,7 +70,9 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -76,7 +84,9 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -88,7 +98,9 @@ func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
ret, err := parseResponse[api.User](resp)
|
||||
return &ret, err
|
||||
|
||||
@@ -30,11 +30,8 @@ var (
|
||||
Issued: ptr("api"),
|
||||
LastLogin: &time.Time{},
|
||||
Name: "M. Essam",
|
||||
Permissions: &api.UserPermissions{
|
||||
DashboardView: ptr(api.UserPermissionsDashboardViewFull),
|
||||
},
|
||||
Role: "user",
|
||||
Status: api.UserStatusActive,
|
||||
Role: "user",
|
||||
Status: api.UserStatusActive,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ var (
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
flag.Parse()
|
||||
startPprof()
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
@@ -2,9 +2,13 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -17,6 +21,17 @@ const (
|
||||
idpSignKeyRefreshEnabledFlagName = "idp-sign-key-refresh-enabled"
|
||||
)
|
||||
|
||||
func startPprof() {
|
||||
go func() {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
log.Debugf("Starting pprof server on 0.0.0.0:6060")
|
||||
if err := http.ListenAndServe("0.0.0.0:6060", nil); err != nil {
|
||||
log.Fatalf("pprof server failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var (
|
||||
dnsDomain string
|
||||
mgmtDataDir string
|
||||
|
||||
@@ -202,7 +202,7 @@ func BuildManager(
|
||||
if err != nil {
|
||||
initialInterval = 1
|
||||
} else {
|
||||
initialInterval = int64(interval) * 10
|
||||
initialInterval = int64(interval) * 2
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second)
|
||||
am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
|
||||
@@ -603,11 +603,15 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
}
|
||||
|
||||
for _, otherUser := range account.Users {
|
||||
if otherUser.IsServiceUser {
|
||||
if otherUser.Id == userID {
|
||||
continue
|
||||
}
|
||||
|
||||
if otherUser.Id == userID {
|
||||
if otherUser.IsServiceUser {
|
||||
err = am.deleteServiceUser(ctx, accountID, userID, otherUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -712,7 +716,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||
accountIDString := fmt.Sprintf("%v", accountID)
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountIDString)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -721,7 +725,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData))
|
||||
log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString)
|
||||
|
||||
dataMap := make(map[string]*idp.UserData, len(userData))
|
||||
for _, datum := range userData {
|
||||
@@ -729,7 +733,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
||||
}
|
||||
|
||||
matchedUserData := make([]*idp.UserData, 0)
|
||||
for _, user := range account.Users {
|
||||
for _, user := range accountUsers {
|
||||
if user.IsServiceUser {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -115,5 +116,6 @@ type Manager interface {
|
||||
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||
}
|
||||
|
||||
@@ -853,6 +853,42 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account.Users["service-user-1"] = &types.User{
|
||||
Id: "service-user-1",
|
||||
Role: types.UserRoleAdmin,
|
||||
IsServiceUser: true,
|
||||
Issued: types.UserIssuedAPI,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"pat-1": {
|
||||
ID: "pat-1",
|
||||
UserID: "service-user-1",
|
||||
Name: "service-user-1",
|
||||
HashedToken: "hashedToken",
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
}
|
||||
account.Users[userId] = &types.User{
|
||||
Id: "service-user-2",
|
||||
Role: types.UserRoleUser,
|
||||
IsServiceUser: true,
|
||||
Issued: types.UserIssuedAPI,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"pat-2": {
|
||||
ID: "pat-2",
|
||||
UserID: userId,
|
||||
Name: userId,
|
||||
HashedToken: "hashedToken",
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.DeleteAccount(context.Background(), account.Id, userId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -862,6 +898,14 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||
}
|
||||
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
}
|
||||
|
||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
|
||||
@@ -3,8 +3,11 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,6 +16,7 @@ import (
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
@@ -47,6 +51,11 @@ type GRPCServer struct {
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
|
||||
syncLimiter *rate.Limiter
|
||||
loginLimiter *rate.Limiter
|
||||
loginLimiterStore sync.Map
|
||||
loginPeerLimit rate.Limit
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -76,6 +85,41 @@ func NewServer(
|
||||
}
|
||||
}
|
||||
|
||||
multiplier := time.Second
|
||||
d, e := time.ParseDuration(os.Getenv("NB_LOGIN_RATE"))
|
||||
if e == nil {
|
||||
multiplier = d
|
||||
}
|
||||
|
||||
loginRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_RATE_PER_M"))
|
||||
if loginRatePerS == 0 || err != nil {
|
||||
loginRatePerS = 200
|
||||
}
|
||||
|
||||
loginBurst, err := strconv.Atoi(os.Getenv("NB_LOGIN_BURST"))
|
||||
if loginBurst == 0 || err != nil {
|
||||
loginBurst = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("login burst limit set to %d", loginBurst)
|
||||
|
||||
loginPeerRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_PEER_RATE_PER_M"))
|
||||
if loginPeerRatePerS == 0 || err != nil {
|
||||
loginPeerRatePerS = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("login rate limit set to %d/min", loginRatePerS)
|
||||
|
||||
syncRatePerS, err := strconv.Atoi(os.Getenv("NB_SYNC_RATE_PER_M"))
|
||||
if syncRatePerS == 0 || err != nil {
|
||||
syncRatePerS = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("sync rate limit set to %d/min", syncRatePerS)
|
||||
|
||||
syncBurst, err := strconv.Atoi(os.Getenv("NB_SYNC_BURST"))
|
||||
if syncBurst == 0 || err != nil {
|
||||
syncBurst = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("sync burst limit set to %d", syncBurst)
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -87,6 +131,9 @@ func NewServer(
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
syncLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(syncRatePerS)), syncBurst),
|
||||
loginLimiter: rate.NewLimiter(rate.Every(multiplier/time.Duration(loginRatePerS)), loginBurst),
|
||||
loginPeerLimit: rate.Every(time.Minute / time.Duration(loginPeerRatePerS)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -128,11 +175,18 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
reqStart := time.Now()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
|
||||
if !s.syncLimiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return status.Errorf(codes.Internal, "temp rate limit reached")
|
||||
}
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
@@ -416,15 +470,58 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
|
||||
limiterIface, ok := s.loginLimiterStore.Load(req.WgPubKey)
|
||||
if !ok {
|
||||
// Check global limiter before allowing a new peer limiter
|
||||
if !s.loginLimiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (global limit)")
|
||||
}
|
||||
|
||||
// Create new limiter for this peer
|
||||
newLimiter := rate.NewLimiter(s.loginPeerLimit, 1000)
|
||||
s.loginLimiterStore.Store(req.WgPubKey, newLimiter)
|
||||
|
||||
if !newLimiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (new peer limit)")
|
||||
}
|
||||
} else {
|
||||
// Use existing limiter for this peer
|
||||
limiter := limiterIface.(*rate.Limiter)
|
||||
if !limiter.Allow() {
|
||||
time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (peer limit)")
|
||||
}
|
||||
}
|
||||
|
||||
// limiter, _ := s.loginLimiterStore.LoadOrStore(req.WgPubKey, rate.NewLimiter(s.loginPeerLimit, 1))
|
||||
// if !limiter.(*rate.Limiter).Allow() {
|
||||
// time.Sleep(time.Millisecond * time.Duration(rand.IntN(10)*100))
|
||||
// log.WithContext(ctx).Warnf("rate limit exceeded for %s", req.WgPubKey)
|
||||
// return nil, status.Errorf(codes.Internal, "temp rate limit reached")
|
||||
// }
|
||||
//
|
||||
// if os.Getenv("ENABLE_LOGIN_RATE_LIMIT") == "true" {
|
||||
// if !s.loginLimiter.Allow() {
|
||||
// return nil, status.Errorf(codes.Internal, "temp rate limit reached")
|
||||
// }
|
||||
// }
|
||||
|
||||
reqStart := time.Now()
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
|
||||
@@ -216,11 +216,25 @@ components:
|
||||
UserPermissions:
|
||||
type: object
|
||||
properties:
|
||||
dashboard_view:
|
||||
description: User's permission to view the dashboard
|
||||
type: string
|
||||
enum: [ "limited", "blocked", "full" ]
|
||||
example: limited
|
||||
is_restricted:
|
||||
type: boolean
|
||||
description: Indicates whether this User's Peers view is restricted
|
||||
modules:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: boolean
|
||||
propertyNames:
|
||||
type: string
|
||||
description: The operation type
|
||||
propertyNames:
|
||||
type: string
|
||||
description: The module name
|
||||
example: {"networks": { "read": true, "create": false, "update": false, "delete": false}, "peers": { "read": false, "create": false, "update": false, "delete": false} }
|
||||
required:
|
||||
- modules
|
||||
- is_restricted
|
||||
UserRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
@@ -178,13 +178,6 @@ const (
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
)
|
||||
|
||||
// Defines values for UserPermissionsDashboardView.
|
||||
const (
|
||||
UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked"
|
||||
UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full"
|
||||
UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited"
|
||||
)
|
||||
|
||||
// Defines values for GetApiEventsNetworkTrafficParamsType.
|
||||
const (
|
||||
GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP"
|
||||
@@ -1757,13 +1750,11 @@ type UserCreateRequest struct {
|
||||
|
||||
// UserPermissions defines model for UserPermissions.
|
||||
type UserPermissions struct {
|
||||
// DashboardView User's permission to view the dashboard
|
||||
DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"`
|
||||
// IsRestricted Indicates whether this User's Peers view is restricted
|
||||
IsRestricted bool `json:"is_restricted"`
|
||||
Modules map[string]map[string]bool `json:"modules"`
|
||||
}
|
||||
|
||||
// UserPermissionsDashboardView User's permission to view the dashboard
|
||||
type UserPermissionsDashboardView string
|
||||
|
||||
// UserRequest defines model for UserRequest.
|
||||
type UserRequest struct {
|
||||
// AutoGroups Group IDs to auto-assign to peers registered by this user
|
||||
|
||||
@@ -4,10 +4,17 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -23,6 +30,9 @@ import (
|
||||
// Handler is a handler that returns peers of the account
|
||||
type Handler struct {
|
||||
accountManager account.Manager
|
||||
rateLimiter *rate.Limiter
|
||||
limiterStore sync.Map
|
||||
reqLimit rate.Limit
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
@@ -35,8 +45,15 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
func NewHandler(accountManager account.Manager) *Handler {
|
||||
apiRatePerM, err := strconv.Atoi(os.Getenv("NB_API_RATE_PER_M"))
|
||||
if apiRatePerM == 0 || err != nil {
|
||||
apiRatePerM = 60
|
||||
}
|
||||
log.Infof("peers API rate limit set to %d/min", apiRatePerM)
|
||||
return &Handler{
|
||||
accountManager: accountManager,
|
||||
rateLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(apiRatePerM)), 1),
|
||||
reqLimit: rate.Every(time.Minute / time.Duration(apiRatePerM)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +71,11 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
}
|
||||
|
||||
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
if !h.rateLimiter.Allow() {
|
||||
util.WriteError(ctx, fmt.Errorf("temp rate limit reached"), w)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
@@ -91,7 +113,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusPreconditionRequired, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -184,9 +206,40 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
}
|
||||
func getCallerIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header first (can be a comma-separated list)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Use first IP in the list
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
// Then check X-Real-IP
|
||||
if xrip := r.Header.Get("X-Real-IP"); xrip != "" {
|
||||
return xrip
|
||||
}
|
||||
|
||||
// Fallback to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr // may be raw IP
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
ip := getCallerIP(r)
|
||||
limiter, _ := h.limiterStore.LoadOrStore(ip, rate.NewLimiter(h.reqLimit, 1))
|
||||
if !limiter.(*rate.Limiter).Allow() {
|
||||
log.WithContext(r.Context()).Errorf("rate limit exceeded for IP: %s", ip)
|
||||
util.WriteError(r.Context(), status.Errorf(status.StatusTooManyRequests, "temp rate limit reached"), w)
|
||||
return
|
||||
}
|
||||
//if !h.rateLimiter.Allow() {
|
||||
// util.WriteError(r.Context(), status.Errorf(status.StatusTooManyRequests, "temp rate limit reached"), w)
|
||||
// return
|
||||
//}
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
@@ -272,15 +273,33 @@ func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
user, err := h.accountManager.GetCurrentUserInfo(ctx, accountID, userID)
|
||||
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(user, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId))
|
||||
}
|
||||
|
||||
func toUserWithPermissionsResponse(user *users.UserInfoWithPermissions, userID string) *api.User {
|
||||
response := toUserResponse(user.UserInfo, userID)
|
||||
|
||||
// stringify modules and operations keys
|
||||
modules := make(map[string]map[string]bool)
|
||||
for module, operations := range user.Permissions {
|
||||
modules[string(module)] = make(map[string]bool)
|
||||
for op, val := range operations {
|
||||
modules[string(module)][string(op)] = val
|
||||
}
|
||||
}
|
||||
|
||||
response.Permissions = &api.UserPermissions{
|
||||
IsRestricted: user.Restricted,
|
||||
Modules: modules,
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
@@ -316,8 +335,5 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
IsBlocked: user.IsBlocked,
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
Permissions: &api.UserPermissions{
|
||||
DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,12 +13,16 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -107,7 +111,7 @@ func initUsersTestData() *handler {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
|
||||
info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false})
|
||||
info, err := update.Copy().ToUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -124,8 +128,8 @@ func initUsersTestData() *handler {
|
||||
|
||||
return nil
|
||||
},
|
||||
GetCurrentUserInfoFunc: func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
|
||||
switch userID {
|
||||
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
switch userAuth.UserId {
|
||||
case "not-found":
|
||||
return nil, status.NewUserNotFoundError("not-found")
|
||||
case "not-of-account":
|
||||
@@ -135,52 +139,68 @@ func initUsersTestData() *handler {
|
||||
case "service-user":
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
case "owner":
|
||||
return &types.UserInfo{
|
||||
ID: "owner",
|
||||
Name: "",
|
||||
Role: "owner",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "full",
|
||||
return &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "owner",
|
||||
Name: "",
|
||||
Role: "owner",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Owner),
|
||||
}, nil
|
||||
case "regular-user":
|
||||
return &types.UserInfo{
|
||||
ID: "regular-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "limited",
|
||||
return &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "regular-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
}, nil
|
||||
|
||||
case "admin-user":
|
||||
return &types.UserInfo{
|
||||
ID: "admin-user",
|
||||
Name: "",
|
||||
Role: "admin",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "full",
|
||||
return &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "admin-user",
|
||||
Name: "",
|
||||
Role: "admin",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Admin),
|
||||
}, nil
|
||||
case "restricted-user":
|
||||
return &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "restricted-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Restricted: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("user id %s not handled", userID)
|
||||
return nil, fmt.Errorf("user id %s not handled", userAuth.UserId)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -546,6 +566,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestAuth nbcontext.UserAuth
|
||||
expectedResult *api.User
|
||||
}{
|
||||
{
|
||||
name: "without auth",
|
||||
@@ -575,16 +596,78 @@ func TestCurrentUser(t *testing.T) {
|
||||
name: "owner",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "owner"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "owner",
|
||||
Role: "owner",
|
||||
Status: "active",
|
||||
IsBlocked: false,
|
||||
IsCurrent: ptr(true),
|
||||
IsServiceUser: ptr(false),
|
||||
AutoGroups: []string{},
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "regular user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "regular-user",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsBlocked: false,
|
||||
IsCurrent: ptr(true),
|
||||
IsServiceUser: ptr(false),
|
||||
AutoGroups: []string{},
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "admin user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "admin-user",
|
||||
Role: "admin",
|
||||
Status: "active",
|
||||
IsBlocked: false,
|
||||
IsCurrent: ptr(true),
|
||||
IsServiceUser: ptr(false),
|
||||
AutoGroups: []string{},
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "restricted user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "restricted-user",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsBlocked: false,
|
||||
IsCurrent: ptr(true),
|
||||
IsServiceUser: ptr(false),
|
||||
AutoGroups: []string{},
|
||||
Issued: ptr("api"),
|
||||
LastLogin: ptr(time.Time{}),
|
||||
Permissions: &api.UserPermissions{
|
||||
IsRestricted: true,
|
||||
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -603,10 +686,42 @@ func TestCurrentUser(t *testing.T) {
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code, "handler returned wrong status code")
|
||||
|
||||
if tc.expectedResult != nil {
|
||||
var result api.User
|
||||
require.NoError(t, json.NewDecoder(res.Body).Decode(&result))
|
||||
assert.EqualValues(t, *tc.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptr[T any, PT *T](x T) PT {
|
||||
return &x
|
||||
}
|
||||
|
||||
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
permissions := roles.Permissions{}
|
||||
|
||||
for k := range modules.All {
|
||||
if rolePermissions, ok := role.Permissions[k]; ok {
|
||||
permissions[k] = rolePermissions
|
||||
continue
|
||||
}
|
||||
permissions[k] = role.AutoAllowNew
|
||||
}
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool {
|
||||
modules := make(map[string]map[string]bool)
|
||||
for module, operations := range permissions {
|
||||
modules[string(module)] = make(map[string]bool)
|
||||
for op, val := range operations {
|
||||
modules[string(module)][string(op)] = val
|
||||
}
|
||||
}
|
||||
return modules
|
||||
}
|
||||
|
||||
@@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
httpStatus = http.StatusUnauthorized
|
||||
case status.BadRequest:
|
||||
httpStatus = http.StatusBadRequest
|
||||
case status.StatusTooManyRequests:
|
||||
httpStatus = http.StatusTooManyRequests
|
||||
default:
|
||||
}
|
||||
msg = strings.ToLower(err.Error())
|
||||
|
||||
@@ -3,12 +3,14 @@ package port_forwarding
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type Controller interface {
|
||||
SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string)
|
||||
GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error)
|
||||
SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer)
|
||||
GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error)
|
||||
GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error)
|
||||
IsPeerInIngressPorts(ctx context.Context, accountID, peerID string) (bool, error)
|
||||
}
|
||||
|
||||
@@ -19,11 +21,15 @@ func NewControllerMock() *ControllerMock {
|
||||
return &ControllerMock{}
|
||||
}
|
||||
|
||||
func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) {
|
||||
func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) {
|
||||
func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) {
|
||||
return make(map[string]*nbtypes.NetworkMap), nil
|
||||
}
|
||||
|
||||
func (c *ControllerMock) GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) {
|
||||
return make(map[string]*nbtypes.NetworkMap), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -352,3 +352,24 @@ func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string,
|
||||
log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error {
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasIndex(&model, indexName) {
|
||||
log.WithContext(ctx).Debugf("index %s does not exist in table %T, no migration needed", indexName, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Migrator().DropIndex(&model, indexName); err != nil {
|
||||
return fmt.Errorf("failed to drop index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -227,3 +227,25 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
|
||||
|
||||
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||
}
|
||||
|
||||
func TestDropIndex(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&types.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
exist := db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
|
||||
assert.True(t, exist, "Should have the index")
|
||||
|
||||
err = migration.DropIndex[types.SetupKey](context.Background(), db, "idx_setup_keys_account_id")
|
||||
require.NoError(t, err, "Migration should not fail to remove index")
|
||||
|
||||
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
|
||||
assert.False(t, exist, "Should not have the index")
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -115,7 +116,7 @@ type MockAccountManager struct {
|
||||
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
}
|
||||
|
||||
@@ -882,9 +883,13 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
|
||||
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
if am.GetCurrentUserInfoFunc != nil {
|
||||
return am.GetCurrentUserInfoFunc(ctx, accountID, userID)
|
||||
return am.GetCurrentUserInfoFunc(ctx, userAuth)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
// noop
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string {
|
||||
}
|
||||
|
||||
type NetworkResource struct {
|
||||
ID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
type NetworkRouter struct {
|
||||
ID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
Peer string
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Description string
|
||||
|
||||
@@ -49,20 +49,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
for _, peer := range accountPeers {
|
||||
if user.IsRegularUser() && user.Id != peer.UserID {
|
||||
// only display peers that belong to the current user if the current user is not an admin
|
||||
continue
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
// @note if the user has permission to read peers it shows all account peers
|
||||
if allowed {
|
||||
return peers, nil
|
||||
return accountPeers, nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
@@ -70,10 +59,22 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
if settings.RegularUsersViewBlocked {
|
||||
if user.IsRestrictable() && settings.RegularUsersViewBlocked {
|
||||
return []*nbpeer.Peer{}, nil
|
||||
}
|
||||
|
||||
// @note if it does not have permission read peers then only display it's own peers
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
for _, peer := range accountPeers {
|
||||
if user.Id != peer.UserID {
|
||||
continue
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers)
|
||||
}
|
||||
|
||||
@@ -418,7 +419,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, err
|
||||
@@ -452,6 +453,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
|
||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||
// The peer property is just a placeholder for the Peer properties to pass further
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
startGlobal := time.Now()
|
||||
if setupKey == "" && userID == "" {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
@@ -504,6 +506,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
var ephemeral bool
|
||||
var groupsToAdd []string
|
||||
var allowExtraDNSLabels bool
|
||||
|
||||
start := time.Now()
|
||||
if addedByUser {
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
|
||||
if err != nil {
|
||||
@@ -536,6 +540,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("AddPeer: setup key get took %v", time.Since(start))
|
||||
|
||||
start = time.Now()
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
@@ -544,16 +551,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
}
|
||||
log.WithContext(ctx).Debugf("AddPeer: idp took %v", time.Since(start))
|
||||
|
||||
start = time.Now()
|
||||
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("AddPeer: free label took %v", time.Since(start))
|
||||
|
||||
start = time.Now()
|
||||
freeIP, err := getFreeIP(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
log.WithContext(ctx).Debugf("AddPeer: ip took %v", time.Since(start))
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
@@ -577,17 +589,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("AddPeer: settings took %v", time.Since(start))
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
@@ -599,8 +616,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("AddPeer: geo took %v", time.Since(start))
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
start = time.Now()
|
||||
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
@@ -615,11 +635,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("AddPeer: add peer to group took %v", time.Since(start))
|
||||
|
||||
start = time.Now()
|
||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("AddPeer: add peer to account took %v", time.Since(start))
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
@@ -637,11 +662,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("AddPeer: is peer in active group took %v", time.Since(start))
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
@@ -656,8 +684,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
log.WithContext(ctx).Debugf("AddPeer took %v", time.Since(startGlobal))
|
||||
|
||||
if updateAccountPeers {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
@@ -996,48 +1023,56 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
mstart := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
|
||||
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(mstart))
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
start := time.Now()
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("GetAccountWithBackpressure: took %s", time.Since(start))
|
||||
start = time.Now()
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("GetValidatedPeers: took %s", time.Since(start))
|
||||
start = time.Now()
|
||||
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks: took %s", time.Since(start))
|
||||
start = time.Now()
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
|
||||
log.WithContext(ctx).Debugf("GetPeersCustomZone: took %s", time.Since(start))
|
||||
start = time.Now()
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("GetProxyNetworkMaps: took %s", time.Since(start))
|
||||
start = time.Now()
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
|
||||
log.WithContext(ctx).Debugf("GetPeerNetworkMap: took %s", time.Since(start))
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
@@ -1165,13 +1200,16 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
// UpdateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
globalStart := time.Now()
|
||||
start := time.Now()
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Infof("updateAccountPeers: getAccount took %s", time.Since(start))
|
||||
|
||||
start := time.Now()
|
||||
start = time.Now()
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
@@ -1179,6 +1217,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("updateAccountPeers: validatePeers took %s", time.Since(start))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
@@ -1188,11 +1228,21 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID)
|
||||
start = time.Now()
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Infof("updateAccountPeers: getProxyNetworkMaps took %s", time.Since(start))
|
||||
for _, id := range []string{"d07kd1ei389c73dq19gg", "d07kcaui389c73dq19g0", "d0e7uo6i389c73f040v0"} {
|
||||
peerMap, ok := proxyNetworkMaps[id]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Infof("updateAccountPeers xxx: proxy network map %s not found", id)
|
||||
continue
|
||||
}
|
||||
log.WithContext(ctx).Infof("updateAccountPeers xxx: peer %s has %d peers, %d offline peers, %d, firewall rules, %d forwarding rules, %d routing rules", id, len(peerMap.Peers), len(peerMap.OfflinePeers), len(peerMap.FirewallRules), len(peerMap.ForwardingRules), len(peerMap.RoutesFirewallRules))
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
@@ -1225,16 +1275,22 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
|
||||
log.WithContext(ctx).Infof("updateAccountPeers: toSyncResponse took %s", time.Since(start))
|
||||
start = time.Now()
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
log.WithContext(ctx).Infof("updateAccountPeers: sending update toSyncResponse took %s", time.Since(start))
|
||||
}(peer)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
wg.Wait()
|
||||
log.WithContext(ctx).Infof("updateAccountPeers: waiting for updates to complete took %s", time.Since(globalStart))
|
||||
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1291,7 +1347,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId)
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return
|
||||
|
||||
@@ -24,7 +24,7 @@ type Peer struct {
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
Name string `gorm:"index"`
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
|
||||
@@ -20,6 +20,8 @@ type Manager interface {
|
||||
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
|
||||
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||
|
||||
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -96,3 +98,22 @@ func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID strin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) {
|
||||
roleMap, ok := roles.RolesMap[role]
|
||||
if !ok {
|
||||
return roles.Permissions{}, status.NewUserRoleNotFoundError(string(role))
|
||||
}
|
||||
|
||||
permissions := roles.Permissions{}
|
||||
|
||||
for k := range modules.All {
|
||||
if rolePermissions, ok := roleMap.Permissions[k]; ok {
|
||||
permissions[k] = rolePermissions
|
||||
continue
|
||||
}
|
||||
permissions[k] = roleMap.AutoAllowNew
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
@@ -38,6 +38,21 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetPermissionsByRole mocks base method.
|
||||
func (m *MockManager) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPermissionsByRole", ctx, role)
|
||||
ret0, _ := ret[0].(roles.Permissions)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPermissionsByRole indicates an expected call of GetPermissionsByRole.
|
||||
func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
|
||||
}
|
||||
|
||||
// ValidateAccountAccess mocks base method.
|
||||
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -17,3 +17,19 @@ const (
|
||||
SetupKeys Module = "setup_keys"
|
||||
Pats Module = "pats"
|
||||
)
|
||||
|
||||
var All = map[Module]struct{}{
|
||||
Networks: {},
|
||||
Peers: {},
|
||||
Groups: {},
|
||||
Settings: {},
|
||||
Accounts: {},
|
||||
Dns: {},
|
||||
Nameservers: {},
|
||||
Events: {},
|
||||
Policies: {},
|
||||
Routes: {},
|
||||
Users: {},
|
||||
SetupKeys: {},
|
||||
Pats: {},
|
||||
}
|
||||
|
||||
@@ -23,9 +23,9 @@ var NetworkAdmin = RolePermissions{
|
||||
},
|
||||
modules.Groups: {
|
||||
operations.Read: true,
|
||||
operations.Create: false,
|
||||
operations.Update: false,
|
||||
operations.Delete: false,
|
||||
operations.Create: true,
|
||||
operations.Update: true,
|
||||
operations.Delete: true,
|
||||
},
|
||||
modules.Settings: {
|
||||
operations.Read: true,
|
||||
@@ -87,5 +87,11 @@ var NetworkAdmin = RolePermissions{
|
||||
operations.Update: true,
|
||||
operations.Delete: true,
|
||||
},
|
||||
modules.Peers: {
|
||||
operations.Read: true,
|
||||
operations.Create: false,
|
||||
operations.Update: false,
|
||||
operations.Delete: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -37,6 +37,8 @@ const (
|
||||
|
||||
// Unauthenticated indicates that user is not authenticated due to absence of valid credentials
|
||||
Unauthenticated Type = 10
|
||||
|
||||
StatusTooManyRequests = 11
|
||||
)
|
||||
|
||||
// Type is a type of the Error
|
||||
|
||||
@@ -802,7 +802,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
|
||||
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
|
||||
var account types.Account
|
||||
result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account)
|
||||
result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account)
|
||||
if result.Error != nil {
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -1311,7 +1311,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
|
||||
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID)
|
||||
|
||||
if nameFilter != "" {
|
||||
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
|
||||
query = query.Where("name = ?", nameFilter)
|
||||
}
|
||||
if ipFilter != "" {
|
||||
query = query.Where("ip LIKE ?", "%"+ipFilter+"%")
|
||||
@@ -1683,18 +1683,26 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength,
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||
}
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
|
||||
return fmt.Errorf("delete policy rules: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPolicyNotFoundError(policyID)
|
||||
}
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Where(accountAndIDQueryCondition, accountID, policyID).
|
||||
Delete(&types.Policy{})
|
||||
|
||||
return nil
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPolicyNotFoundError(policyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
|
||||
@@ -60,10 +60,10 @@ func Test_NewStore(t *testing.T) {
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Errorf("expected to create a new Store")
|
||||
t.Fatalf("expected to create a new Store")
|
||||
}
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
t.Fatalf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1115,7 +1115,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
|
||||
group := &types.Group{
|
||||
ID: "group-id",
|
||||
AccountID: "account-id",
|
||||
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
Name: "group-name",
|
||||
Issued: "api",
|
||||
Peers: nil,
|
||||
|
||||
@@ -315,6 +315,15 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true)
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[networkTypes.Network](ctx, db, "idx_networks_id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[resourceTypes.NetworkResource](ctx, db, "idx_network_resources_id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ const (
|
||||
// Group of the peers for ACL
|
||||
type Group struct {
|
||||
// ID of the group
|
||||
ID string
|
||||
ID string `gorm:"primaryKey"`
|
||||
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
@@ -65,11 +65,6 @@ type UserInfo struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
Permissions UserPermissions `json:"permissions"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
|
||||
// User represents a user of the system
|
||||
@@ -132,21 +127,18 @@ func (u *User) IsRegularUser() bool {
|
||||
return !u.HasAdminPower() && !u.IsServiceUser
|
||||
}
|
||||
|
||||
// IsRestrictable checks whether a user is in a restrictable role.
|
||||
func (u *User) IsRestrictable() bool {
|
||||
return u.Role == UserRoleUser || u.Role == UserRoleBillingAdmin
|
||||
}
|
||||
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
dashboardViewPermissions := "full"
|
||||
if !u.HasAdminPower() {
|
||||
dashboardViewPermissions = "limited"
|
||||
if settings.RegularUsersViewBlocked {
|
||||
dashboardViewPermissions = "blocked"
|
||||
}
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
@@ -159,9 +151,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
Permissions: UserPermissions{
|
||||
DashboardView: dashboardViewPermissions,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
@@ -184,9 +173,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
Permissions: UserPermissions{
|
||||
DashboardView: dashboardViewPermissions,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
@@ -122,11 +124,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -138,7 +135,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
|
||||
am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||
|
||||
return newUser.ToUserInfo(idpUser, settings)
|
||||
return newUser.ToUserInfo(idpUser)
|
||||
}
|
||||
|
||||
// createNewIdpUser validates the invite and creates a new user in the IdP
|
||||
@@ -360,6 +357,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user
|
||||
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
@@ -727,19 +725,14 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
|
||||
// If the AccountManager has a non-nil idpManager and the User is not a service user,
|
||||
// it will attempt to look up the UserData from the cache.
|
||||
func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.ToUserInfo(userData, settings)
|
||||
return user.ToUserInfo(userData)
|
||||
}
|
||||
return user.ToUserInfo(nil, settings)
|
||||
return user.ToUserInfo(nil)
|
||||
}
|
||||
|
||||
// validateUserUpdate validates the update operation for a user.
|
||||
@@ -879,17 +872,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userInfosMap := make(map[string]*types.UserInfo)
|
||||
|
||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||
if len(queriedUsers) == 0 {
|
||||
for _, accountUser := range accountUsers {
|
||||
info, err := accountUser.ToUserInfo(nil, settings)
|
||||
info, err := accountUser.ToUserInfo(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -902,7 +890,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
for _, localUser := range accountUsers {
|
||||
var info *types.UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.ToUserInfo(queriedUser, settings)
|
||||
info, err = localUser.ToUserInfo(queriedUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -912,14 +900,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
name = localUser.ServiceUserName
|
||||
}
|
||||
|
||||
dashboardViewPermissions := "full"
|
||||
if !localUser.HasAdminPower() {
|
||||
dashboardViewPermissions = "limited"
|
||||
if settings.RegularUsersViewBlocked {
|
||||
dashboardViewPermissions = "blocked"
|
||||
}
|
||||
}
|
||||
|
||||
info = &types.UserInfo{
|
||||
ID: localUser.Id,
|
||||
Email: "",
|
||||
@@ -929,7 +909,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
Status: string(types.UserStatusActive),
|
||||
IsServiceUser: localUser.IsServiceUser,
|
||||
NonDeletable: localUser.NonDeletable,
|
||||
Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions},
|
||||
}
|
||||
}
|
||||
userInfosMap[info.ID] = info
|
||||
@@ -1239,8 +1218,10 @@ func validateUserInvite(invite *types.UserInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentUserInfo retrieves the account's current user info
|
||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
|
||||
// GetCurrentUserInfo retrieves the account's current user info and permissions
|
||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1258,10 +1239,25 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userInfo, err := am.getUserInfo(ctx, user, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
userWithPermissions := &users.UserInfoWithPermissions{
|
||||
UserInfo: userInfo,
|
||||
Restricted: !userAuth.IsChild && user.IsRestrictable() && settings.RegularUsersViewBlocked,
|
||||
}
|
||||
|
||||
permissions, err := am.permissionsManager.GetPermissionsByRole(ctx, user.Role)
|
||||
if err == nil {
|
||||
userWithPermissions.Permissions = permissions
|
||||
}
|
||||
|
||||
return userWithPermissions, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,10 @@ import (
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -1020,90 +1023,6 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
assert.Equal(t, 2, regular)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
role types.UserRole
|
||||
limitedViewSettings bool
|
||||
expectedDashboardPermissions string
|
||||
}{
|
||||
{
|
||||
name: "Regular user, no limited view settings",
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: false,
|
||||
expectedDashboardPermissions: "limited",
|
||||
},
|
||||
{
|
||||
name: "Admin user, no limited view settings",
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: false,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
{
|
||||
name: "Owner, no limited view settings",
|
||||
role: types.UserRoleOwner,
|
||||
limitedViewSettings: false,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
{
|
||||
name: "Regular user, limited view settings",
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: true,
|
||||
expectedDashboardPermissions: "blocked",
|
||||
},
|
||||
{
|
||||
name: "Admin user, limited view settings",
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: true,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
{
|
||||
name: "Owner, limited view settings",
|
||||
role: types.UserRoleOwner,
|
||||
limitedViewSettings: true,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI)
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
delete(account.Users, mockUserID)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
|
||||
users, err := am.ListUsers(context.Background(), mockAccountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, len(users))
|
||||
|
||||
userInfo, _ := users[0].ToUserInfo(nil, account.Settings)
|
||||
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
@@ -1654,121 +1573,154 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
accountId string
|
||||
userId string
|
||||
userAuth nbcontext.UserAuth
|
||||
expectedErr error
|
||||
expectedResult *types.UserInfo
|
||||
expectedResult *users.UserInfoWithPermissions
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
accountId: account1.Id,
|
||||
userId: "not-found",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"},
|
||||
expectedErr: status.NewUserNotFoundError("not-found"),
|
||||
},
|
||||
{
|
||||
name: "not part of account",
|
||||
accountId: account1.Id,
|
||||
userId: "account2Owner",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
|
||||
expectedErr: status.NewUserNotPartOfAccountError(),
|
||||
},
|
||||
{
|
||||
name: "blocked",
|
||||
accountId: account1.Id,
|
||||
userId: "blocked-user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
|
||||
expectedErr: status.NewUserBlockedError(),
|
||||
},
|
||||
{
|
||||
name: "service user",
|
||||
accountId: account1.Id,
|
||||
userId: "service-user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"},
|
||||
expectedErr: status.NewPermissionDeniedError(),
|
||||
},
|
||||
{
|
||||
name: "owner user",
|
||||
accountId: account1.Id,
|
||||
userId: "account1Owner",
|
||||
expectedResult: &types.UserInfo{
|
||||
ID: "account1Owner",
|
||||
Name: "",
|
||||
Role: "owner",
|
||||
AutoGroups: []string{},
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "full",
|
||||
name: "owner user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "account1Owner",
|
||||
Name: "",
|
||||
Role: "owner",
|
||||
AutoGroups: []string{},
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Owner),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "regular user",
|
||||
accountId: account1.Id,
|
||||
userId: "regular-user",
|
||||
expectedResult: &types.UserInfo{
|
||||
ID: "regular-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "limited",
|
||||
name: "regular user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "regular-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "admin user",
|
||||
accountId: account1.Id,
|
||||
userId: "admin-user",
|
||||
expectedResult: &types.UserInfo{
|
||||
ID: "admin-user",
|
||||
Name: "",
|
||||
Role: "admin",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "full",
|
||||
name: "admin user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "admin-user",
|
||||
Name: "",
|
||||
Role: "admin",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Admin),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "settings blocked regular user",
|
||||
accountId: account2.Id,
|
||||
userId: "settings-blocked-user",
|
||||
expectedResult: &types.UserInfo{
|
||||
ID: "settings-blocked-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
Permissions: types.UserPermissions{
|
||||
DashboardView: "blocked",
|
||||
name: "settings blocked regular user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "settings-blocked-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Restricted: true,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "settings blocked regular user child account",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "settings-blocked-user",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Restricted: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "settings blocked owner user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "account2Owner",
|
||||
Name: "",
|
||||
Role: "owner",
|
||||
AutoGroups: []string{},
|
||||
Status: "active",
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
NonDeletable: false,
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Owner),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := am.GetCurrentUserInfo(context.Background(), tc.accountId, tc.userId)
|
||||
result, err := am.GetCurrentUserInfo(context.Background(), tc.userAuth)
|
||||
|
||||
if tc.expectedErr != nil {
|
||||
assert.Equal(t, err, tc.expectedErr)
|
||||
@@ -1780,3 +1732,17 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
permissions := roles.Permissions{}
|
||||
|
||||
for k := range modules.All {
|
||||
if rolePermissions, ok := role.Permissions[k]; ok {
|
||||
permissions[k] = rolePermissions
|
||||
continue
|
||||
}
|
||||
permissions[k] = role.AutoAllowNew
|
||||
}
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
14
management/server/users/user.go
Normal file
14
management/server/users/user.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// Wrapped UserInfo with Role Permissions
|
||||
type UserInfoWithPermissions struct {
|
||||
*types.UserInfo
|
||||
|
||||
Permissions roles.Permissions
|
||||
Restricted bool
|
||||
}
|
||||
@@ -28,6 +28,16 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the base TLS config
|
||||
tlsClientConfig := quictls.ClientQUICTLSConfig()
|
||||
|
||||
// Set ServerName to hostname if not an IP address
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
// It's a hostname, not an IP - modify directly
|
||||
tlsClientConfig.ServerName = host
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
KeepAlivePeriod: 30 * time.Second,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
@@ -47,7 +57,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
|
||||
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
@@ -61,12 +71,29 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(address, "rels://"):
|
||||
host = address[7:]
|
||||
defaultPort = "443"
|
||||
case strings.HasPrefix(address, "rel://"):
|
||||
host = address[6:]
|
||||
defaultPort = "80"
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(address, "rels://") {
|
||||
return address[7:], nil
|
||||
finalHost, finalPort, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "missing port") {
|
||||
return host + ":" + defaultPort, nil
|
||||
}
|
||||
|
||||
// return any other split error as is
|
||||
return "", err
|
||||
}
|
||||
return address[6:], nil
|
||||
|
||||
return finalHost + ":" + finalPort, nil
|
||||
}
|
||||
|
||||
@@ -224,16 +224,22 @@ check_use_bin_variable() {
|
||||
|
||||
install_netbird() {
|
||||
if [ -x "$(command -v netbird)" ]; then
|
||||
status_output=$(netbird status)
|
||||
if echo "$status_output" | grep -q 'Management: Connected' && echo "$status_output" | grep -q 'Signal: Connected'; then
|
||||
echo "NetBird service is running, please stop it before proceeding"
|
||||
exit 1
|
||||
fi
|
||||
status_output="$(netbird status 2>&1 || true)"
|
||||
|
||||
if [ -n "$status_output" ]; then
|
||||
echo "NetBird seems to be installed already, please remove it before proceeding"
|
||||
exit 1
|
||||
fi
|
||||
if echo "$status_output" | grep -q 'failed to connect to daemon error: context deadline exceeded'; then
|
||||
echo "Warning: could not reach NetBird daemon (timeout), proceeding anyway"
|
||||
else
|
||||
if echo "$status_output" | grep -q 'Management: Connected' && \
|
||||
echo "$status_output" | grep -q 'Signal: Connected'; then
|
||||
echo "NetBird service is running, please stop it before proceeding"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -n "$status_output" ]; then
|
||||
echo "NetBird seems to be installed already, please remove it before proceeding"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# Run the installation, if a desktop environment is not detected
|
||||
|
||||
@@ -7,6 +7,8 @@ const (
|
||||
ClientHeaderValue = "netbird"
|
||||
// GetURLPath is the path for the GetURL request
|
||||
GetURLPath = "/upload-url"
|
||||
|
||||
DefaultBundleURL = "https://upload.debug.netbird.io" + GetURLPath
|
||||
)
|
||||
|
||||
// GetURLResponse is the response for the GetURL request
|
||||
|
||||
Reference in New Issue
Block a user