Compare commits

..

1 Commits

Author SHA1 Message Date
bcmmbaga
24053750d9 add debug logs for user state change
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-30 15:21:01 +03:00
61 changed files with 900 additions and 2345 deletions

View File

@@ -235,6 +235,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())

View File

@@ -42,6 +42,7 @@ const (
blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath
)
var (
@@ -187,7 +188,7 @@ func init() {
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -1,6 +1,7 @@
package dns_test
import (
"net"
"testing"
"github.com/miekg/dns"
@@ -8,7 +9,6 @@ import (
"github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
r.SetQuestion("example.com.", dns.TypeA)
// Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
@@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
@@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
@@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
}).Once()
// Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
@@ -325,6 +325,20 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3.AssertExpectations(t)
}
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct {
name string
@@ -411,7 +425,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
// Create test request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations
for priority, handler := range handlers {
@@ -457,7 +471,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
@@ -476,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
@@ -492,7 +506,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
@@ -505,7 +519,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
@@ -661,7 +675,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&test.MockResponseWriter{}, r)
chain.ServeDNS(&mockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
@@ -805,7 +819,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup handler expectations
for pattern, handler := range handlers {
@@ -955,7 +969,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)

View File

@@ -0,0 +1,130 @@
package dns
import (
"fmt"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type registrationMap map[string]struct{}
type localResolver struct {
registeredMap registrationMap
records sync.Map // key: string (domain_class_type), value: []dns.RR
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(r)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
if len(r.Question) == 0 {
return nil
}
question := r.Question[0]
question.Name = strings.ToLower(question.Name)
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
value, found := d.records.Load(key)
if !found {
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
r.Question[0].Qtype = dns.TypeCNAME
return d.lookupRecords(r)
}
return nil
}
records, ok := value.([]dns.RR)
if !ok {
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
return nil
}
// if there's more than one record, rotate them (round-robin)
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records.Store(key, records)
}
return records
}
// registerRecord stores a new record by appending it to any existing list
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
rr, err := dns.NewRR(record.String())
if err != nil {
return "", fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
// load any existing slice of records, then append
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
records := existing.([]dns.RR)
records = append(records, rr)
// store updated slice
d.records.Store(key, records)
return key, nil
}
// deleteRecord removes *all* records under the recordKey.
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}
// buildRecordKey consistently generates a key: name_class_type
func buildRecordKey(name string, class, qType uint16) string {
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
}
func (d *localResolver) probeAvailability() {}

View File

@@ -1,149 +0,0 @@
package local
import (
"fmt"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
)
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
}
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
// String returns a string representation of the local resolver
func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// TODO: return success if we have a different record type for the same name, relevant for search domains
replyMessage.Rcode = dns.RcodeNameError
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return nil
}
recordsCopy := slices.Clone(records)
d.mu.RUnlock()
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
}
d.mu.Unlock()
}
return recordsCopy
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}
// RegisterRecord stores a new record by appending it to any existing list
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.registerRecord(record)
}
// registerRecord performs the registration with the lock already held
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
rr, err := dns.NewRR(record.String())
if err != nil {
return fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
q := dns.Question{
Name: strings.ToLower(dns.Fqdn(header.Name)),
Qtype: header.Rrtype,
Qclass: header.Class,
}
d.records[q] = append(d.records[q], rr)
return nil
}

View File

@@ -1,472 +0,0 @@
package local
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := NewResolver()
_ = resolver.RegisterRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}
// TestLocalResolver_Update_StaleRecord verifies that updating
// a record correctly replaces the old one, preventing stale entries.
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
recordName := "host.example.com."
recordType := dns.TypeA
recordClass := dns.ClassINET
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
}
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
rrSlice1, found1 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found1, "Record key %s not found after first update", recordKey)
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
rrSlice2, found2 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found2, "Record key %s not found after second update", recordKey)
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
}
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
// with the same question are stored properly
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
resolver := NewResolver()
recordName := "multi.example.com."
recordType := dns.TypeA
// Create two records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
Name: recordName,
Qtype: recordType,
Qclass: dns.ClassINET,
}
// Verify both records are stored
resolver.mu.RLock()
records, found := resolver.records[question]
resolver.mu.RUnlock()
require.True(t, found, "Records for question %v not found", question)
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
// Verify both record data values are present
recordStrings := []string{records[0].String(), records[1].String()}
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
}
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
func TestLocalResolver_RecordRotation(t *testing.T) {
resolver := NewResolver()
recordName := "rotation.example.com."
recordType := dns.TypeA
// Create three records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
}
record3 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
// First lookup - should return the records in original order
var responses [3]*dns.Msg
// Perform three lookups to verify rotation
for i := 0; i < 3; i++ {
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responses[i] = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
}
// Verify all three responses contain answers
for i, resp := range responses {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
}
// Verify the first record in each response is different due to rotation
firstRecordIPs := []string{
responses[0].Answer[0].String(),
responses[1].Answer[0].String(),
responses[2].Answer[0].String(),
}
// Each record should be different (rotated)
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
// After three rotations, we should have cycled through all records
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
}
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
resolver := NewResolver()
// Create record with lowercase name
lowerCaseRecord := nbdns.SimpleRecord{
Name: "lower.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.10.10.10",
}
// Create record with mixed case name
mixedCaseRecord := nbdns.SimpleRecord{
Name: "MiXeD.ExAmPlE.CoM.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "20.20.20.20",
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
queryName string
expectedRData string
shouldResolve bool
}{
{
name: "Query lowercase with lowercase record",
queryName: "lower.example.com.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query uppercase with lowercase record",
queryName: "LOWER.EXAMPLE.COM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query mixed case with lowercase record",
queryName: "LoWeR.eXaMpLe.CoM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query lowercase with mixed case record",
queryName: "mixed.example.com.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query uppercase with mixed case record",
queryName: "MIXED.EXAMPLE.COM.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query with different casing pattern",
queryName: "mIxEd.ExaMpLe.cOm.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query non-existent domain",
queryName: "nonexistent.example.com.",
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case name
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 {
// Expected no answer, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected IP address %s, got: %s",
tc.expectedRData, answerString)
})
}
}
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
// to checking for CNAME records when the requested record type isn't found
func TestLocalResolver_CNAMEFallback(t *testing.T) {
resolver := NewResolver()
// Create a CNAME record (but no A record for this name)
cnameRecord := nbdns.SimpleRecord{
Name: "alias.example.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
// Create an A record for the CNAME target
targetRecord := nbdns.SimpleRecord{
Name: "target.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.100.100",
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
queryName string
queryType uint16
expectedType string
expectedRData string
shouldResolve bool
}{
{
name: "Directly query CNAME record",
queryName: "alias.example.com.",
queryType: dns.TypeCNAME,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query A record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query AAAA record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeAAAA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query direct A record",
queryName: "target.example.com.",
queryType: dns.TypeA,
expectedType: "A",
expectedRData: "192.168.100.100",
shouldResolve: true,
},
{
name: "Query non-existent name",
queryName: "nonexistent.example.com.",
queryType: dns.TypeA,
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case parameters
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
// Expected no resolution, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a successful response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedType,
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
})
}
}

View File

@@ -0,0 +1,88 @@
package dns
import (
"strings"
"testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_, _ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -0,0 +1,26 @@
package dns
import (
"net"
"github.com/miekg/dns"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -15,8 +15,6 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -48,6 +46,8 @@ type Server interface {
ProbeAvailability()
}
type handlerID string
type nsGroupsByDomain struct {
domain string
groups []*nbdns.NameServerGroup
@@ -61,7 +61,7 @@ type DefaultServer struct {
mux sync.Mutex
service service
dnsMuxMap registeredHandlerMap
localResolver *local.Resolver
localResolver *localResolver
wgInterface WGIface
hostManager hostManager
updateSerial uint64
@@ -84,9 +84,9 @@ type DefaultServer struct {
type handlerWithStop interface {
dns.Handler
Stop()
ProbeAvailability()
ID() types.HandlerID
stop()
probeAvailability()
id() handlerID
}
type handlerWrapper struct {
@@ -95,7 +95,7 @@ type handlerWrapper struct {
priority int
}
type registeredHandlerMap map[types.HandlerID]handlerWrapper
type registeredHandlerMap map[handlerID]handlerWrapper
// NewDefaultServer returns a new dns server
func NewDefaultServer(
@@ -171,14 +171,16 @@ func newDefaultServer(
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
ctxCancel: stop,
disableSys: disableSys,
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
ctx: ctx,
ctxCancel: stop,
disableSys: disableSys,
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
statusRecorder: statusRecorder,
stateManager: stateManager,
@@ -401,7 +403,7 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Add(1)
go func(mux handlerWithStop) {
defer wg.Done()
mux.ProbeAvailability()
mux.probeAvailability()
}(mux.handler)
}
wg.Wait()
@@ -418,7 +420,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.service.Stop()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -432,7 +434,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.localResolver.Update(localRecords)
s.updateLocalResolver(localRecordsByDomain)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -514,9 +516,11 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(
customZones []nbdns.CustomZone,
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord
localRecords := make(map[string][]nbdns.SimpleRecord)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -530,13 +534,17 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityMatchDomain,
})
// group all records under this domain
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
key := buildRecordKey(record.Name, class, uint16(record.Type))
localRecords[key] = append(localRecords[key], record)
}
}
@@ -619,7 +627,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
}
if len(handler.upstreamServers) == 0 {
handler.Stop()
handler.stop()
log.Errorf("received a nameserver group with an invalid nameserver list")
continue
}
@@ -648,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority)
existing.handler.Stop()
existing.handler.stop()
}
muxUpdateMap := make(registeredHandlerMap)
@@ -659,7 +667,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
containsRootUpdate = true
}
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
muxUpdateMap[update.handler.id()] = update
}
// If there's no root update and we had a root handler, restore it
@@ -675,6 +683,33 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap
}
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
// remove old records that are no longer present
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for _, recs := range update {
for _, rec := range recs {
// convert the record to a dns.RR and register
key, err := s.localResolver.registerRecord(rec)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v",
rec.String(), err)
continue
}
updatedMap[key] = struct{}{}
}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}

View File

@@ -23,9 +23,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -110,7 +107,6 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
@@ -124,21 +120,22 @@ func TestUpdateDNSServer(t *testing.T) {
},
}
dummyHandler := local.NewResolver()
dummyHandler := &localResolver{}
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalRecords []nbdns.SimpleRecord
initLocalMap registrationMap
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registeredHandlerMap
expectedLocalQs []dns.Question
expectedLocalMap registrationMap
}{
{
name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -162,30 +159,30 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
dummyHandler.ID(): handlerWrapper{
dummyHandler.id(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
generateDummyHandler(".", nameServers).id(): handlerWrapper{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
},
{
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
handler: dummyHandler,
priority: PriorityMatchDomain,
},
@@ -208,7 +205,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
@@ -219,22 +216,22 @@ func TestUpdateDNSServer(t *testing.T) {
priority: PriorityMatchDomain,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -252,11 +249,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -274,11 +271,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -293,17 +290,17 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
@@ -313,13 +310,13 @@ func TestUpdateDNSServer(t *testing.T) {
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalQs: []dns.Question{},
expectedLocalMap: make(registrationMap),
},
{
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Disabled Service Should clean map",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
@@ -329,7 +326,7 @@ func TestUpdateDNSServer(t *testing.T) {
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalQs: []dns.Question{},
expectedLocalMap: make(registrationMap),
},
}
@@ -380,7 +377,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -402,23 +399,15 @@ func TestUpdateDNSServer(t *testing.T) {
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
for key := range testCase.expectedLocalMap {
_, found := dnsServer.localResolver.registeredMap[key]
if !found {
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
}
}
})
}
@@ -502,12 +491,11 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
handler: &localResolver{},
priority: PriorityMatchDomain,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -594,7 +582,7 @@ func TestDNSServerStartStop(t *testing.T) {
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop()
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
@@ -642,11 +630,13 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
@@ -1014,7 +1004,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tc.query, dns.TypeA)
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once()
@@ -1047,9 +1037,9 @@ type mockHandler struct {
}
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
func (m *mockHandler) stop() {}
func (m *mockHandler) probeAvailability() {}
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
type mockService struct{}
@@ -1123,7 +1113,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
name string
initialHandlers registeredHandlerMap
updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain
expectedHandlers map[string]string // map[handlerID]domain
description string
}{
{
@@ -1419,7 +1409,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
handler, exists := server.dnsMuxMap[handlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
@@ -1428,9 +1418,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}
// Verify no unexpected handlers exist
for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
for handlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(handlerID)]
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
}
// Verify the handlerChain state and order
@@ -1706,7 +1696,7 @@ func TestExtraDomains(t *testing.T) {
handlerChain: NewHandlerChain(),
wgInterface: &mocWGIface{},
hostManager: mockHostConfig,
localResolver: &local.Resolver{},
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
@@ -1791,7 +1781,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &local.Resolver{},
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
@@ -1843,7 +1833,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &local.Resolver{},
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
@@ -1926,7 +1916,7 @@ func TestDomainCaseHandling(t *testing.T) {
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &local.Resolver{},
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),

View File

@@ -1,26 +0,0 @@
package test
import (
"net"
"github.com/miekg/dns"
)
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *MockResponseWriter) Close() error { return nil }
func (rw *MockResponseWriter) TsigStatus() error { return nil }
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
func (rw *MockResponseWriter) Hijack() {}

View File

@@ -1,3 +0,0 @@
package types
type HandlerID string

View File

@@ -19,7 +19,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
@@ -82,21 +81,21 @@ func (u *upstreamResolverBase) String() string {
}
// ID returns the unique handler ID
func (u *upstreamResolverBase) ID() types.HandlerID {
func (u *upstreamResolverBase) id() handlerID {
servers := slices.Clone(u.upstreamServers)
slices.Sort(servers)
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ",")))
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
func (u *upstreamResolverBase) MatchSubdomains() bool {
return true
}
func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
}
@@ -199,9 +198,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
)
}
// ProbeAvailability tests all upstream servers simultaneously and
// probeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
func (u *upstreamResolverBase) probeAvailability() {
u.mutex.Lock()
defer u.mutex.Unlock()

View File

@@ -8,8 +8,6 @@ import (
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
@@ -68,7 +66,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
@@ -132,7 +130,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &test.MockResponseWriter{
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}

View File

@@ -51,16 +51,14 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
}
if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil
}
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
}
log.Infof("debug bundle uploaded to %s with key %s", req.GetUploadURL(), key)
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
}

View File

@@ -51,17 +51,14 @@ const (
)
func main() {
daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags()
daemonAddr, showSettings, showNetworks, errorMsg, saveLogsInFile := parseFlags()
// Initialize file logging if needed.
var logFile string
if saveLogsInFile {
file, err := initLogFile()
if err != nil {
if err := initLogFile(); err != nil {
log.Errorf("error while initializing log: %v", err)
return
}
logFile = file
}
// Create the Fyne application.
@@ -75,13 +72,13 @@ func main() {
}
// Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug)
client := newServiceClient(daemonAddr, a, showSettings, showNetworks)
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if showSettings || showNetworks || showDebug {
if showSettings || showNetworks {
a.Run()
return
}
@@ -102,7 +99,7 @@ func main() {
}
// parseFlags reads and returns all needed command-line flags.
func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) {
func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg string, saveLogsInFile bool) {
defaultDaemonAddr := "unix:///var/run/netbird.sock"
if runtime.GOOS == "windows" {
defaultDaemonAddr = "tcp://127.0.0.1:41731"
@@ -110,17 +107,25 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool
flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
flag.BoolVar(&showSettings, "settings", false, "run settings window")
flag.BoolVar(&showNetworks, "networks", false, "run networks window")
flag.BoolVar(&showDebug, "debug", false, "run debug window")
flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
flag.Parse()
return
}
// initLogFile initializes logging into a file.
func initLogFile() (string, error) {
logFile := path.Join(os.TempDir(), fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
return logFile, util.InitLog("trace", logFile)
func initLogFile() error {
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
return util.InitLog("trace", logFile)
}
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
@@ -163,10 +168,9 @@ var iconConnectingMacOS []byte
var iconErrorMacOS []byte
type serviceClient struct {
ctx context.Context
cancel context.CancelFunc
addr string
conn proto.DaemonServiceClient
ctx context.Context
addr string
conn proto.DaemonServiceClient
icAbout []byte
icConnected []byte
@@ -227,14 +231,13 @@ type serviceClient struct {
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
showNetworks bool
wNetworks fyne.Window
showRoutes bool
wRoutes fyne.Window
eventManager *event.Manager
exitNodeMu sync.Mutex
mExitNodeItems []menuHandler
logFile string
}
type menuHandler struct {
@@ -245,30 +248,25 @@ type menuHandler struct {
// newServiceClient instance constructor
//
// This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient {
ctx, cancel := context.WithCancel(context.Background())
func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient {
s := &serviceClient{
ctx: ctx,
cancel: cancel,
ctx: context.Background(),
addr: addr,
app: a,
logFile: logFile,
sendNotification: false,
showAdvancedSettings: showSettings,
showNetworks: showNetworks,
showRoutes: showRoutes,
update: version.NewUpdate(),
}
s.setNewIcons()
switch {
case showSettings:
if showSettings {
s.showSettingsUI()
case showNetworks:
return s
} else if showRoutes {
s.showNetworksUI()
case showDebug:
s.showDebugUI()
}
return s
@@ -315,8 +313,6 @@ func (s *serviceClient) updateIcon() {
func (s *serviceClient) showSettingsUI() {
// add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings")
s.wSettings.SetOnClosed(s.cancel)
s.iMngURL = widget.NewEntry()
s.iAdminURL = widget.NewEntry()
s.iConfigFile = widget.NewEntry()
@@ -747,10 +743,11 @@ func (s *serviceClient) onTrayReady() {
s.runSelfCommand("settings", "true")
}()
case <-s.mCreateDebugBundle.ClickedCh:
s.mCreateDebugBundle.Disable()
go func() {
defer s.mCreateDebugBundle.Enable()
s.runSelfCommand("debug", "true")
if err := s.createAndOpenDebugBundle(); err != nil {
log.Errorf("Failed to create debug bundle: %v", err)
s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle"))
}
}()
case <-s.mQuit.ClickedCh:
systray.Quit()
@@ -792,7 +789,7 @@ func (s *serviceClient) onTrayReady() {
func (s *serviceClient) runSelfCommand(command, arg string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("Error getting executable path: %v", err)
log.Errorf("show %s failed with error: %v", command, err)
return
}
@@ -801,48 +798,14 @@ func (s *serviceClient) runSelfCommand(command, arg string) {
fmt.Sprintf("--daemon-addr=%s", s.addr),
)
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("Error closing log file %s: %v", s.logFile, err)
}
}()
}
log.Printf("Running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, s.addr)
err = cmd.Run()
if err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("Command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
} else {
log.Printf("Failed to start/run command '%s %s': %v", command, arg, err)
}
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start %s UI: %v, %s", command, err, string(out))
return
}
log.Printf("Command '%s %s' completed successfully.", command, arg)
}
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
if s.logFile == "" {
// attach child's streams to parent's streams
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return nil
if len(out) != 0 {
log.Infof("command %s executed: %s", command, string(out))
}
out, err := os.OpenFile(s.logFile, os.O_WRONLY|os.O_APPEND, 0)
if err != nil {
log.Errorf("Failed to open log file %s: %v", s.logFile, err)
return nil
}
cmd.Stdout = out
cmd.Stderr = out
return out
}
func normalizedVersion(version string) string {
@@ -855,7 +818,9 @@ func normalizedVersion(version string) string {
// onTrayExit is called when the tray icon is closed.
func (s *serviceClient) onTrayExit() {
s.cancel()
for _, item := range s.mExitNodeItems {
item.cancel()
}
}
// getSrvClient connection to the service.
@@ -864,7 +829,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil
}
ctx, cancel := context.WithTimeout(s.ctx, timeout)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := grpc.DialContext(

View File

@@ -3,721 +3,48 @@
package main
import (
"context"
"fmt"
"path/filepath"
"strconv"
"sync"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types"
)
// Initial state for the debug collection
type debugInitialState struct {
wasDown bool
logLevel proto.LogLevel
isLevelTrace bool
}
// Debug collection parameters
type debugCollectionParams struct {
duration time.Duration
anonymize bool
systemInfo bool
upload bool
uploadURL string
enablePersistence bool
}
// UI components for progress tracking
type progressUI struct {
statusLabel *widget.Label
progressBar *widget.ProgressBar
uiControls []fyne.Disableable
window fyne.Window
}
func (s *serviceClient) showDebugUI() {
w := s.app.NewWindow("NetBird Debug")
w.SetOnClosed(s.cancel)
w.Resize(fyne.NewSize(600, 500))
w.SetFixedSize(true)
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
systemInfoCheck.SetChecked(true)
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
uploadCheck.SetChecked(true)
uploadURLLabel := widget.NewLabel("Debug upload URL:")
uploadURL := widget.NewEntry()
uploadURL.SetText(uptypes.DefaultBundleURL)
uploadURL.SetPlaceHolder("Enter upload URL")
uploadURLContainer := container.NewVBox(
uploadURLLabel,
uploadURL,
)
uploadCheck.OnChanged = func(checked bool) {
if checked {
uploadURLContainer.Show()
} else {
uploadURLContainer.Hide()
}
}
debugModeContainer := container.NewHBox()
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
runForDurationCheck.SetChecked(true)
forLabel := widget.NewLabel("for")
durationInput := widget.NewEntry()
durationInput.SetText("1")
minutesLabel := widget.NewLabel("minute")
durationInput.Validator = func(s string) error {
return validateMinute(s, minutesLabel)
}
noteLabel := widget.NewLabel("Note: NetBird will be brought up and down during collection")
runForDurationCheck.OnChanged = func(checked bool) {
if checked {
forLabel.Show()
durationInput.Show()
minutesLabel.Show()
noteLabel.Show()
} else {
forLabel.Hide()
durationInput.Hide()
minutesLabel.Hide()
noteLabel.Hide()
}
}
debugModeContainer.Add(runForDurationCheck)
debugModeContainer.Add(forLabel)
debugModeContainer.Add(durationInput)
debugModeContainer.Add(minutesLabel)
statusLabel := widget.NewLabel("")
statusLabel.Hide()
progressBar := widget.NewProgressBar()
progressBar.Hide()
createButton := widget.NewButton("Create Debug Bundle", nil)
// UI controls that should be disabled during debug collection
uiControls := []fyne.Disableable{
anonymizeCheck,
systemInfoCheck,
uploadCheck,
uploadURL,
runForDurationCheck,
durationInput,
createButton,
}
createButton.OnTapped = s.getCreateHandler(
statusLabel,
progressBar,
uploadCheck,
uploadURL,
anonymizeCheck,
systemInfoCheck,
runForDurationCheck,
durationInput,
uiControls,
w,
)
content := container.NewVBox(
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
widget.NewLabel(""),
anonymizeCheck,
systemInfoCheck,
uploadCheck,
uploadURLContainer,
widget.NewLabel(""),
debugModeContainer,
noteLabel,
widget.NewLabel(""),
statusLabel,
progressBar,
createButton,
)
paddedContent := container.NewPadded(content)
w.SetContent(paddedContent)
w.Show()
}
func validateMinute(s string, minutesLabel *widget.Label) error {
if val, err := strconv.Atoi(s); err != nil || val < 1 {
return fmt.Errorf("must be a number ≥ 1")
}
if s == "1" {
minutesLabel.SetText("minute")
} else {
minutesLabel.SetText("minutes")
}
return nil
}
// disableUIControls disables the provided UI controls
func disableUIControls(controls []fyne.Disableable) {
for _, control := range controls {
control.Disable()
}
}
// enableUIControls enables the provided UI controls
func enableUIControls(controls []fyne.Disableable) {
for _, control := range controls {
control.Enable()
}
}
func (s *serviceClient) getCreateHandler(
statusLabel *widget.Label,
progressBar *widget.ProgressBar,
uploadCheck *widget.Check,
uploadURL *widget.Entry,
anonymizeCheck *widget.Check,
systemInfoCheck *widget.Check,
runForDurationCheck *widget.Check,
duration *widget.Entry,
uiControls []fyne.Disableable,
w fyne.Window,
) func() {
return func() {
disableUIControls(uiControls)
statusLabel.Show()
var url string
if uploadCheck.Checked {
url = uploadURL.Text
if url == "" {
statusLabel.SetText("Error: Upload URL is required when upload is enabled")
enableUIControls(uiControls)
return
}
}
params := &debugCollectionParams{
anonymize: anonymizeCheck.Checked,
systemInfo: systemInfoCheck.Checked,
upload: uploadCheck.Checked,
uploadURL: url,
enablePersistence: true,
}
runForDuration := runForDurationCheck.Checked
if runForDuration {
minutes, err := time.ParseDuration(duration.Text + "m")
if err != nil {
statusLabel.SetText(fmt.Sprintf("Error: Invalid duration: %v", err))
enableUIControls(uiControls)
return
}
params.duration = minutes
statusLabel.SetText(fmt.Sprintf("Running in debug mode for %d minutes...", int(minutes.Minutes())))
progressBar.Show()
progressBar.SetValue(0)
go s.handleRunForDuration(
statusLabel,
progressBar,
uiControls,
w,
params,
)
return
}
statusLabel.SetText("Creating debug bundle...")
go s.handleDebugCreation(
anonymizeCheck.Checked,
systemInfoCheck.Checked,
uploadCheck.Checked,
url,
statusLabel,
uiControls,
w,
)
}
}
func (s *serviceClient) handleRunForDuration(
statusLabel *widget.Label,
progressBar *widget.ProgressBar,
uiControls []fyne.Disableable,
w fyne.Window,
params *debugCollectionParams,
) {
progressUI := &progressUI{
statusLabel: statusLabel,
progressBar: progressBar,
uiControls: uiControls,
window: w,
}
func (s *serviceClient) createAndOpenDebugBundle() error {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
handleError(progressUI, fmt.Sprintf("Failed to get client for debug: %v", err))
return
}
initialState, err := s.getInitialState(conn)
if err != nil {
handleError(progressUI, err.Error())
return
}
statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI)
if err != nil {
handleError(progressUI, err.Error())
return
}
if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil {
handleError(progressUI, err.Error())
return
}
s.restoreServiceState(conn, initialState)
progressUI.statusLabel.SetText("Bundle created successfully")
}
// Get initial state of the service
func (s *serviceClient) getInitialState(conn proto.DaemonServiceClient) (*debugInitialState, error) {
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
return nil, fmt.Errorf(" get status: %v", err)
}
logLevelResp, err := conn.GetLogLevel(s.ctx, &proto.GetLogLevelRequest{})
if err != nil {
return nil, fmt.Errorf("get log level: %v", err)
}
wasDown := statusResp.Status != string(internal.StatusConnected) &&
statusResp.Status != string(internal.StatusConnecting)
initialLogLevel := logLevelResp.GetLevel()
initialLevelTrace := initialLogLevel >= proto.LogLevel_TRACE
return &debugInitialState{
wasDown: wasDown,
logLevel: initialLogLevel,
isLevelTrace: initialLevelTrace,
}, nil
}
// Handle progress tracking during collection
func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time.Duration, progress *progressUI) {
progress.progressBar.Show()
progress.progressBar.SetValue(0)
startTime := time.Now()
endTime := startTime.Add(duration)
wg.Add(1)
go func() {
defer wg.Done()
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
remaining := time.Until(endTime)
if remaining <= 0 {
remaining = 0
}
elapsed := time.Since(startTime)
progressVal := float64(elapsed) / float64(duration)
if progressVal > 1.0 {
progressVal = 1.0
}
progress.progressBar.SetValue(progressVal)
progress.statusLabel.SetText(fmt.Sprintf("Running with trace logs... %s remaining", formatDuration(remaining)))
}
}
}()
}
func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient,
state *debugInitialState,
enablePersistence bool,
) error {
if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service up: %v", err)
}
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
return fmt.Errorf("set log level to TRACE: %v", err)
}
log.Info("Log level set to TRACE for debug")
}
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("bring service down: %v", err)
}
time.Sleep(time.Second)
if enablePersistence {
if _, err := conn.SetNetworkMapPersistence(s.ctx, &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("enable network map persistence: %v", err)
}
log.Info("Network map persistence enabled for debug")
}
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service back up: %v", err)
}
time.Sleep(time.Second * 3)
return nil
}
func (s *serviceClient) collectDebugData(
conn proto.DaemonServiceClient,
state *debugInitialState,
params *debugCollectionParams,
progress *progressUI,
) (string, error) {
ctx, cancel := context.WithTimeout(s.ctx, params.duration)
defer cancel()
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return "", err
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get post-up status: %v", err)
}
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
wg.Wait()
progress.progressBar.Hide()
progress.statusLabel.SetText("Collecting debug data...")
preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("Failed to get pre-down status: %v", err)
}
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput)
return statusOutput, nil
}
// Create the debug bundle with collected data
func (s *serviceClient) createDebugBundleFromCollection(
conn proto.DaemonServiceClient,
params *debugCollectionParams,
statusOutput string,
progress *progressUI,
) error {
progress.statusLabel.SetText("Creating debug bundle with collected logs...")
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
Status: statusOutput,
SystemInfo: params.systemInfo,
}
if params.upload {
request.UploadURL = params.uploadURL
}
resp, err := conn.DebugBundle(s.ctx, request)
if err != nil {
return fmt.Errorf("create debug bundle: %v", err)
}
// Show appropriate dialog based on upload status
localPath := resp.GetPath()
uploadFailureReason := resp.GetUploadFailureReason()
uploadedKey := resp.GetUploadedKey()
if params.upload {
if uploadFailureReason != "" {
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
} else {
showUploadSuccessDialog(progress.window, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(progress.window, localPath)
}
enableUIControls(progress.uiControls)
return nil
}
// Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("Failed to restore down state: %v", err)
} else {
log.Info("Service state restored to down")
}
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
log.Errorf("Failed to restore log level: %v", err)
} else {
log.Info("Log level restored to original setting")
}
}
}
// Handle errors during debug collection
func handleError(progress *progressUI, errMsg string) {
log.Errorf("%s", errMsg)
progress.statusLabel.SetText(errMsg)
progress.progressBar.Hide()
enableUIControls(progress.uiControls)
}
func (s *serviceClient) handleDebugCreation(
anonymize bool,
systemInfo bool,
upload bool,
uploadURL string,
statusLabel *widget.Label,
uiControls []fyne.Disableable,
w fyne.Window,
) {
log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
anonymize, systemInfo, upload)
resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
if err != nil {
log.Errorf("Failed to create debug bundle: %v", err)
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
enableUIControls(uiControls)
return
}
localPath := resp.GetPath()
uploadFailureReason := resp.GetUploadFailureReason()
uploadedKey := resp.GetUploadedKey()
if upload {
if uploadFailureReason != "" {
showUploadFailedDialog(w, localPath, uploadFailureReason)
} else {
showUploadSuccessDialog(w, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(w, localPath)
}
enableUIControls(uiControls)
statusLabel.SetText("Bundle created successfully")
}
func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploadURL string) (*proto.DebugBundleResponse, error) {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
return fmt.Errorf("get client: %v", err)
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err)
return fmt.Errorf("failed to get status: %v", err)
}
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil)
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
request := &proto.DebugBundleRequest{
Anonymize: anonymize,
resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{
Anonymize: true,
Status: statusOutput,
SystemInfo: systemInfo,
}
if uploadURL != "" {
request.UploadURL = uploadURL
}
resp, err := conn.DebugBundle(s.ctx, request)
SystemInfo: true,
})
if err != nil {
return nil, fmt.Errorf("failed to create debug bundle via daemon: %v", err)
return fmt.Errorf("failed to create debug bundle: %v", err)
}
return resp, nil
}
// formatDuration formats a duration in HH:MM:SS format
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
// createButtonWithAction creates a button with the given label and action
func createButtonWithAction(label string, action func()) *widget.Button {
button := widget.NewButton(label, action)
return button
}
// showUploadFailedDialog displays a dialog when upload fails
func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
content := container.NewVBox(
widget.NewLabel(fmt.Sprintf("Bundle upload failed:\n%s\n\n"+
"A local copy was saved at:\n%s", failureReason, localPath)),
)
customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, w)
buttonBox := container.NewHBox(
createButtonWithAction("Open file", func() {
log.Infof("Attempting to open local file: %s", localPath)
if openErr := open.Start(localPath); openErr != nil {
log.Errorf("Failed to open local file '%s': %v", localPath, openErr)
dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w)
}
}),
createButtonWithAction("Open folder", func() {
folderPath := filepath.Dir(localPath)
log.Infof("Attempting to open local folder: %s", folderPath)
if openErr := open.Start(folderPath); openErr != nil {
log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr)
dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w)
}
}),
)
content.Add(buttonBox)
customDialog.Show()
}
// showUploadSuccessDialog displays a dialog when upload succeeds
func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
log.Infof("Upload key: %s", uploadedKey)
keyEntry := widget.NewEntry()
keyEntry.SetText(uploadedKey)
keyEntry.Disable()
content := container.NewVBox(
widget.NewLabel("Bundle uploaded successfully!"),
widget.NewLabel(""),
widget.NewLabel("Upload key:"),
keyEntry,
widget.NewLabel(""),
widget.NewLabel(fmt.Sprintf("Local copy saved at:\n%s", localPath)),
)
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
copyBtn := createButtonWithAction("Copy key", func() {
w.Clipboard().SetContent(uploadedKey)
log.Info("Upload key copied to clipboard")
})
buttonBox := createButtonBox(localPath, w, copyBtn)
content.Add(buttonBox)
customDialog.Show()
}
// showBundleCreatedDialog displays a dialog when bundle is created without upload
func showBundleCreatedDialog(w fyne.Window, localPath string) {
content := container.NewVBox(
widget.NewLabel(fmt.Sprintf("Bundle created locally at:\n%s\n\n"+
"Administrator privileges may be required to access the file.", localPath)),
)
customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, w)
buttonBox := createButtonBox(localPath, w, nil)
content.Add(buttonBox)
customDialog.Show()
}
func createButtonBox(localPath string, w fyne.Window, elems ...fyne.Widget) *fyne.Container {
box := container.NewHBox()
for _, elem := range elems {
box.Add(elem)
bundleDir := filepath.Dir(resp.GetPath())
if err := open.Start(bundleDir); err != nil {
return fmt.Errorf("failed to open debug bundle directory: %v", err)
}
fileBtn := createButtonWithAction("Open file", func() {
log.Infof("Attempting to open local file: %s", localPath)
if openErr := open.Start(localPath); openErr != nil {
log.Errorf("Failed to open local file '%s': %v", localPath, openErr)
dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w)
}
})
s.app.SendNotification(fyne.NewNotification(
"Debug Bundle",
fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()),
))
folderBtn := createButtonWithAction("Open folder", func() {
folderPath := filepath.Dir(localPath)
log.Infof("Attempting to open local folder: %s", folderPath)
if openErr := open.Start(folderPath); openErr != nil {
log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr)
dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w)
}
})
box.Add(fileBtn)
box.Add(folderBtn)
return box
return nil
}

View File

@@ -34,8 +34,7 @@ const (
type filter string
func (s *serviceClient) showNetworksUI() {
s.wNetworks = s.app.NewWindow("Networks")
s.wNetworks.SetOnClosed(s.cancel)
s.wRoutes = s.app.NewWindow("Networks")
allGrid := container.New(layout.NewGridLayout(3))
go s.updateNetworks(allGrid, allNetworks)
@@ -79,8 +78,8 @@ func (s *serviceClient) showNetworksUI() {
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
s.wNetworks.SetContent(content)
s.wNetworks.Show()
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
}
@@ -149,7 +148,7 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) {
grid.Add(resolvedIPsSelector)
}
s.wNetworks.Content().Refresh()
s.wRoutes.Content().Refresh()
grid.Refresh()
}
@@ -306,7 +305,7 @@ func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.Se
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wNetworks)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
@@ -317,15 +316,14 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container
}
}()
s.wNetworks.SetOnClosed(func() {
s.wRoutes.SetOnClosed(func() {
ticker.Stop()
s.cancel()
})
}
func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.wNetworks.Content().Refresh()
s.wRoutes.Content().Refresh()
s.updateNetworks(grid, f)
}
@@ -375,7 +373,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
node.Selected,
)
ctx, cancel := context.WithCancel(s.ctx)
ctx, cancel := context.WithCancel(context.Background())
s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{
MenuItem: menuItem,
cancel: cancel,

View File

@@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
func (s SimpleRecord) Len() uint16 {
emptyString := s.RData == ""
switch s.Type {
case int(dns.TypeA):
case 1:
if emptyString {
return 0
}
return net.IPv4len
case int(dns.TypeCNAME):
case 5:
if emptyString || s.RData == "." {
return 1
}
return uint16(len(s.RData) + 1)
case int(dns.TypeAAAA):
case 28:
if emptyString {
return 0
}

View File

@@ -20,9 +20,7 @@ func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Account](resp)
return ret, err
}
@@ -38,9 +36,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Account](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.NameserverGroup](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID strin
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}
@@ -98,9 +88,7 @@ func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.DNSSettings](resp)
return &ret, err
}
@@ -116,9 +104,7 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.DNSSettings](resp)
return &ret, err
}

View File

@@ -18,9 +18,7 @@ func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Event](resp)
return ret, err
}

View File

@@ -18,9 +18,7 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Country](resp)
return ret, err
}
@@ -32,9 +30,7 @@ func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode stri
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.City](resp)
return ret, err
}

View File

@@ -20,9 +20,7 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Group](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Network](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network,
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}
@@ -112,9 +102,7 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource,
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.NetworkResource](resp)
return ret, err
}
@@ -126,9 +114,7 @@ func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
@@ -144,9 +130,7 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
@@ -162,9 +146,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
@@ -176,9 +158,7 @@ func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID stri
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}
@@ -204,9 +184,7 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.NetworkRouter](resp)
return ret, err
}
@@ -218,9 +196,7 @@ func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*a
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
@@ -236,9 +212,7 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
@@ -254,9 +228,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string,
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
@@ -268,9 +240,7 @@ func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Peer](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Peer](resp)
return &ret, err
}
@@ -66,9 +60,7 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}
@@ -80,9 +72,7 @@ func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]ap
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}

View File

@@ -20,9 +20,7 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Policy](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, er
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.PostureCheck](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) er
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.Route](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.SetupKey](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKe
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.SetupKey](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJ
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.SetupKeyClear](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.SetupKey](resp)
return &ret, err
}
@@ -84,9 +76,7 @@ func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.PersonalAccessToken](resp)
return ret, err
}
@@ -34,9 +32,7 @@ func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.Perso
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.PersonalAccessToken](resp)
return &ret, err
}
@@ -52,9 +48,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp)
return &ret, err
}
@@ -66,9 +60,7 @@ func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}

View File

@@ -20,9 +20,7 @@ func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[[]api.User](resp)
return ret, err
}
@@ -38,9 +36,7 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.User](resp)
return &ret, err
}
@@ -56,9 +52,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.User](resp)
return &ret, err
}
@@ -70,9 +64,7 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}
@@ -84,9 +76,7 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
return nil
}
@@ -98,9 +88,7 @@ func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
defer resp.Body.Close()
ret, err := parseResponse[api.User](resp)
return &ret, err

View File

@@ -30,8 +30,11 @@ var (
Issued: ptr("api"),
LastLogin: &time.Time{},
Name: "M. Essam",
Role: "user",
Status: api.UserStatusActive,
Permissions: &api.UserPermissions{
DashboardView: ptr(api.UserPermissionsDashboardViewFull),
},
Role: "user",
Status: api.UserStatusActive,
}
)

View File

@@ -9,6 +9,7 @@ import (
"os"
"reflect"
"regexp"
"runtime/debug"
"slices"
"strconv"
"strings"
@@ -712,7 +713,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString)
account, err := am.Store.GetAccount(ctx, accountIDString)
if err != nil {
return nil, nil, err
}
@@ -721,7 +722,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
if err != nil {
return nil, nil, err
}
log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString)
log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData))
dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData {
@@ -729,7 +730,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
}
matchedUserData := make([]*idp.UserData, 0)
for _, user := range accountUsers {
for _, user := range account.Users {
if user.IsServiceUser {
continue
}
@@ -1054,6 +1055,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
return "", err
}
log.WithContext(ctx).Debugf("created new regular user ID: %s, domainAccountId: %s, accountID: %s Trace: %s", userAuth.UserId,
domainAccountID, userAuth.AccountId, debug.Stack(),
)
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route"
)
@@ -116,5 +115,5 @@ type Manager interface {
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error)
}

View File

@@ -216,25 +216,11 @@ components:
UserPermissions:
type: object
properties:
is_restricted:
type: boolean
description: Indicates whether this User's Peers view is restricted
modules:
type: object
additionalProperties:
type: object
additionalProperties:
type: boolean
propertyNames:
type: string
description: The operation type
propertyNames:
type: string
description: The module name
example: {"networks": { "read": true, "create": false, "update": false, "delete": false}, "peers": { "read": false, "create": false, "update": false, "delete": false} }
required:
- modules
- is_restricted
dashboard_view:
description: User's permission to view the dashboard
type: string
enum: [ "limited", "blocked", "full" ]
example: limited
UserRequest:
type: object
properties:

View File

@@ -178,6 +178,13 @@ const (
UserStatusInvited UserStatus = "invited"
)
// Defines values for UserPermissionsDashboardView.
const (
UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked"
UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full"
UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited"
)
// Defines values for GetApiEventsNetworkTrafficParamsType.
const (
GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP"
@@ -1750,11 +1757,13 @@ type UserCreateRequest struct {
// UserPermissions defines model for UserPermissions.
type UserPermissions struct {
// IsRestricted Indicates whether this User's Peers view is restricted
IsRestricted bool `json:"is_restricted"`
Modules map[string]map[string]bool `json:"modules"`
// DashboardView User's permission to view the dashboard
DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"`
}
// UserPermissionsDashboardView User's permission to view the dashboard
type UserPermissionsDashboardView string
// UserRequest defines model for UserRequest.
type UserRequest struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
@@ -273,33 +272,15 @@ func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := h.accountManager.GetCurrentUserInfo(ctx, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId))
}
func toUserWithPermissionsResponse(user *users.UserInfoWithPermissions, userID string) *api.User {
response := toUserResponse(user.UserInfo, userID)
// stringify modules and operations keys
modules := make(map[string]map[string]bool)
for module, operations := range user.Permissions {
modules[string(module)] = make(map[string]bool)
for op, val := range operations {
modules[string(module)][string(op)] = val
}
}
response.Permissions = &api.UserPermissions{
IsRestricted: user.Restricted,
Modules: modules,
}
return response
util.WriteJSONObject(r.Context(), w, toUserResponse(user, userID))
}
func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
@@ -335,5 +316,8 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
IsBlocked: user.IsBlocked,
LastLogin: &user.LastLogin,
Issued: &user.Issued,
Permissions: &api.UserPermissions{
DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView),
},
}
}

View File

@@ -13,16 +13,12 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
)
const (
@@ -111,7 +107,7 @@ func initUsersTestData() *handler {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
info, err := update.Copy().ToUserInfo(nil)
info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false})
if err != nil {
return nil, err
}
@@ -128,8 +124,8 @@ func initUsersTestData() *handler {
return nil
},
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
switch userAuth.UserId {
GetCurrentUserInfoFunc: func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
switch userID {
case "not-found":
return nil, status.NewUserNotFoundError("not-found")
case "not-of-account":
@@ -139,68 +135,52 @@ func initUsersTestData() *handler {
case "service-user":
return nil, status.NewPermissionDeniedError()
case "owner":
return &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "owner",
Name: "",
Role: "owner",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
Issued: "api",
return &types.UserInfo{
ID: "owner",
Name: "",
Role: "owner",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
Issued: "api",
Permissions: types.UserPermissions{
DashboardView: "full",
},
Permissions: mergeRolePermissions(roles.Owner),
}, nil
case "regular-user":
return &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "regular-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
Issued: "api",
return &types.UserInfo{
ID: "regular-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
Issued: "api",
Permissions: types.UserPermissions{
DashboardView: "limited",
},
Permissions: mergeRolePermissions(roles.User),
}, nil
case "admin-user":
return &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "admin-user",
Name: "",
Role: "admin",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
return &types.UserInfo{
ID: "admin-user",
Name: "",
Role: "admin",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
Permissions: types.UserPermissions{
DashboardView: "full",
},
Permissions: mergeRolePermissions(roles.Admin),
}, nil
case "restricted-user":
return &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "restricted-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
},
Permissions: mergeRolePermissions(roles.User),
Restricted: true,
}, nil
}
return nil, fmt.Errorf("user id %s not handled", userAuth.UserId)
return nil, fmt.Errorf("user id %s not handled", userID)
},
},
}
@@ -566,7 +546,6 @@ func TestCurrentUser(t *testing.T) {
name string
expectedStatus int
requestAuth nbcontext.UserAuth
expectedResult *api.User
}{
{
name: "without auth",
@@ -596,78 +575,16 @@ func TestCurrentUser(t *testing.T) {
name: "owner",
requestAuth: nbcontext.UserAuth{UserId: "owner"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "owner",
Role: "owner",
Status: "active",
IsBlocked: false,
IsCurrent: ptr(true),
IsServiceUser: ptr(false),
AutoGroups: []string{},
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)),
},
},
},
{
name: "regular user",
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "regular-user",
Role: "user",
Status: "active",
IsBlocked: false,
IsCurrent: ptr(true),
IsServiceUser: ptr(false),
AutoGroups: []string{},
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
},
},
},
{
name: "admin user",
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "admin-user",
Role: "admin",
Status: "active",
IsBlocked: false,
IsCurrent: ptr(true),
IsServiceUser: ptr(false),
AutoGroups: []string{},
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)),
},
},
},
{
name: "restricted user",
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "restricted-user",
Role: "user",
Status: "active",
IsBlocked: false,
IsCurrent: ptr(true),
IsServiceUser: ptr(false),
AutoGroups: []string{},
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
IsRestricted: true,
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
},
},
},
}
@@ -686,42 +603,10 @@ func TestCurrentUser(t *testing.T) {
res := rr.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, rr.Code, "handler returned wrong status code")
if tc.expectedResult != nil {
var result api.User
require.NoError(t, json.NewDecoder(res.Body).Decode(&result))
assert.EqualValues(t, *tc.expectedResult, result)
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
}
})
}
}
func ptr[T any, PT *T](x T) PT {
return &x
}
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
permissions := roles.Permissions{}
for k := range modules.All {
if rolePermissions, ok := role.Permissions[k]; ok {
permissions[k] = rolePermissions
continue
}
permissions[k] = role.AutoAllowNew
}
return permissions
}
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool {
modules := make(map[string]map[string]bool)
for module, operations := range permissions {
modules[string(module)] = make(map[string]bool)
for op, val := range operations {
modules[string(module)][string(op)] = val
}
}
return modules
}

View File

@@ -352,24 +352,3 @@ func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string,
log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName)
return nil
}
func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
if !db.Migrator().HasIndex(&model, indexName) {
log.WithContext(ctx).Debugf("index %s does not exist in table %T, no migration needed", indexName, model)
return nil
}
if err := db.Migrator().DropIndex(&model, indexName); err != nil {
return fmt.Errorf("failed to drop index %s: %w", indexName, err)
}
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
return nil
}

View File

@@ -227,25 +227,3 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
}
func TestDropIndex(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&types.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&types.SetupKey{
Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
}).Error
require.NoError(t, err, "Failed to insert setup key")
exist := db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
assert.True(t, exist, "Should have the index")
err = migration.DropIndex[types.SetupKey](context.Background(), db, "idx_setup_keys_account_id")
require.NoError(t, err, "Migration should not fail to remove index")
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
assert.False(t, exist, "Should not have the index")
}

View File

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route"
)
@@ -116,7 +115,7 @@ type MockAccountManager struct {
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
}
@@ -883,9 +882,9 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
}
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
if am.GetCurrentUserInfoFunc != nil {
return am.GetCurrentUserInfoFunc(ctx, userAuth)
return am.GetCurrentUserInfoFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}

View File

@@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string {
}
type NetworkResource struct {
ID string `gorm:"primaryKey"`
ID string `gorm:"index"`
NetworkID string `gorm:"index"`
AccountID string `gorm:"index"`
Name string

View File

@@ -10,7 +10,7 @@ import (
)
type NetworkRouter struct {
ID string `gorm:"primaryKey"`
ID string `gorm:"index"`
NetworkID string `gorm:"index"`
AccountID string `gorm:"index"`
Peer string

View File

@@ -7,7 +7,7 @@ import (
)
type Network struct {
ID string `gorm:"primaryKey"`
ID string `gorm:"index"`
AccountID string `gorm:"index"`
Name string
Description string

View File

@@ -49,9 +49,20 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, err
}
// @note if the user has permission to read peers it shows all account peers
peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range accountPeers {
if user.IsRegularUser() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
peers = append(peers, peer)
peersMap[peer.ID] = peer
}
if allowed {
return accountPeers, nil
return peers, nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
@@ -59,22 +70,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, fmt.Errorf("failed to get account settings: %w", err)
}
if user.IsRestrictable() && settings.RegularUsersViewBlocked {
if settings.RegularUsersViewBlocked {
return []*nbpeer.Peer{}, nil
}
// @note if it does not have permission read peers then only display it's own peers
peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range accountPeers {
if user.Id != peer.UserID {
continue
}
peers = append(peers, peer)
peersMap[peer.ID] = peer
}
return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers)
}

View File

@@ -20,8 +20,6 @@ type Manager interface {
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error)
}
type managerImpl struct {
@@ -98,22 +96,3 @@ func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID strin
}
return nil
}
func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) {
roleMap, ok := roles.RolesMap[role]
if !ok {
return roles.Permissions{}, status.NewUserRoleNotFoundError(string(role))
}
permissions := roles.Permissions{}
for k := range modules.All {
if rolePermissions, ok := roleMap.Permissions[k]; ok {
permissions[k] = rolePermissions
continue
}
permissions[k] = roleMap.AutoAllowNew
}
return permissions, nil
}

View File

@@ -38,21 +38,6 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// GetPermissionsByRole mocks base method.
func (m *MockManager) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPermissionsByRole", ctx, role)
ret0, _ := ret[0].(roles.Permissions)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPermissionsByRole indicates an expected call of GetPermissionsByRole.
func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role)
}
// ValidateAccountAccess mocks base method.
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error {
m.ctrl.T.Helper()

View File

@@ -17,19 +17,3 @@ const (
SetupKeys Module = "setup_keys"
Pats Module = "pats"
)
var All = map[Module]struct{}{
Networks: {},
Peers: {},
Groups: {},
Settings: {},
Accounts: {},
Dns: {},
Nameservers: {},
Events: {},
Policies: {},
Routes: {},
Users: {},
SetupKeys: {},
Pats: {},
}

View File

@@ -23,9 +23,9 @@ var NetworkAdmin = RolePermissions{
},
modules.Groups: {
operations.Read: true,
operations.Create: true,
operations.Update: true,
operations.Delete: true,
operations.Create: false,
operations.Update: false,
operations.Delete: false,
},
modules.Settings: {
operations.Read: true,
@@ -87,11 +87,5 @@ var NetworkAdmin = RolePermissions{
operations.Update: true,
operations.Delete: true,
},
modules.Peers: {
operations.Read: true,
operations.Create: false,
operations.Update: false,
operations.Delete: false,
},
},
}

View File

@@ -209,6 +209,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
data, _ := json.MarshalIndent(account, "", " ")
log.WithContext(ctx).Debugf("saved an account Account ID: %s, state: %s Trace: %s", account.Id, data, debug.Stack())
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
return err
@@ -802,7 +806,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
var account types.Account
result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account)
result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account)
if result.Error != nil {
return "", status.NewGetAccountFromStoreError(result.Error)
}

View File

@@ -60,10 +60,10 @@ func Test_NewStore(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Fatalf("expected to create a new Store")
t.Errorf("expected to create a new Store")
}
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Fatalf("expected to create a new empty Accounts map when creating a new FileStore")
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
})
}
@@ -1115,7 +1115,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
group := &types.Group{
ID: "group-id",
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,

View File

@@ -315,15 +315,6 @@ func getMigrations(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true)
},
func(db *gorm.DB) error {
return migration.DropIndex[networkTypes.Network](ctx, db, "idx_networks_id")
},
func(db *gorm.DB) error {
return migration.DropIndex[resourceTypes.NetworkResource](ctx, db, "idx_network_resources_id")
},
func(db *gorm.DB) error {
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
},
}
}

View File

@@ -14,7 +14,7 @@ const (
// Group of the peers for ACL
type Group struct {
// ID of the group
ID string `gorm:"primaryKey"`
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`

View File

@@ -65,6 +65,11 @@ type UserInfo struct {
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
// User represents a user of the system
@@ -127,18 +132,21 @@ func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// IsRestrictable checks whether a user is in a restrictable role.
func (u *User) IsRestrictable() bool {
return u.Role == UserRoleUser || u.Role == UserRoleBillingAdmin
}
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
dashboardViewPermissions := "full"
if !u.HasAdminPower() {
dashboardViewPermissions = "limited"
if settings.RegularUsersViewBlocked {
dashboardViewPermissions = "blocked"
}
}
if userData == nil {
return &UserInfo{
ID: u.Id,
@@ -151,6 +159,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
IsBlocked: u.Blocked,
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
if userData.ID != u.Id {
@@ -173,6 +184,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
IsBlocked: u.Blocked,
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -20,7 +19,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util"
)
@@ -124,6 +122,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
CreatedAt: time.Now().UTC(),
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil {
return nil, err
}
@@ -135,7 +138,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil)
return newUser.ToUserInfo(idpUser)
return newUser.ToUserInfo(idpUser, settings)
}
// createNewIdpUser validates the invite and creates a new user in the IdP
@@ -357,7 +360,6 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, err
}
// @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}
@@ -725,14 +727,19 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if !isNil(am.idpManager) && !user.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
if err != nil {
return nil, err
}
return user.ToUserInfo(userData)
return user.ToUserInfo(userData, settings)
}
return user.ToUserInfo(nil)
return user.ToUserInfo(nil, settings)
}
// validateUserUpdate validates the update operation for a user.
@@ -872,12 +879,17 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
queriedUsers = append(queriedUsers, usersFromIntegration...)
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
userInfosMap := make(map[string]*types.UserInfo)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, accountUser := range accountUsers {
info, err := accountUser.ToUserInfo(nil)
info, err := accountUser.ToUserInfo(nil, settings)
if err != nil {
return nil, err
}
@@ -890,7 +902,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
for _, localUser := range accountUsers {
var info *types.UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.ToUserInfo(queriedUser)
info, err = localUser.ToUserInfo(queriedUser, settings)
if err != nil {
return nil, err
}
@@ -900,6 +912,14 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
name = localUser.ServiceUserName
}
dashboardViewPermissions := "full"
if !localUser.HasAdminPower() {
dashboardViewPermissions = "limited"
if settings.RegularUsersViewBlocked {
dashboardViewPermissions = "blocked"
}
}
info = &types.UserInfo{
ID: localUser.Id,
Email: "",
@@ -909,6 +929,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
Status: string(types.UserStatusActive),
IsServiceUser: localUser.IsServiceUser,
NonDeletable: localUser.NonDeletable,
Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions},
}
}
userInfosMap[info.ID] = info
@@ -1218,10 +1239,8 @@ func validateUserInvite(invite *types.UserInfo) error {
return nil
}
// GetCurrentUserInfo retrieves the account's current user info and permissions
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId
// GetCurrentUserInfo retrieves the account's current user info
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
@@ -1239,25 +1258,10 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
userInfo, err := am.getUserInfo(ctx, user, accountID)
if err != nil {
return nil, err
}
userWithPermissions := &users.UserInfoWithPermissions{
UserInfo: userInfo,
Restricted: !userAuth.IsChild && user.IsRestrictable() && settings.RegularUsersViewBlocked,
}
permissions, err := am.permissionsManager.GetPermissionsByRole(ctx, user.Role)
if err == nil {
userWithPermissions.Permissions = permissions
}
return userWithPermissions, nil
return userInfo, nil
}

View File

@@ -13,10 +13,7 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -1023,6 +1020,90 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
assert.Equal(t, 2, regular)
}
func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
testCases := []struct {
name string
role types.UserRole
limitedViewSettings bool
expectedDashboardPermissions string
}{
{
name: "Regular user, no limited view settings",
role: types.UserRoleUser,
limitedViewSettings: false,
expectedDashboardPermissions: "limited",
},
{
name: "Admin user, no limited view settings",
role: types.UserRoleAdmin,
limitedViewSettings: false,
expectedDashboardPermissions: "full",
},
{
name: "Owner, no limited view settings",
role: types.UserRoleOwner,
limitedViewSettings: false,
expectedDashboardPermissions: "full",
},
{
name: "Regular user, limited view settings",
role: types.UserRoleUser,
limitedViewSettings: true,
expectedDashboardPermissions: "blocked",
},
{
name: "Admin user, limited view settings",
role: types.UserRoleAdmin,
limitedViewSettings: true,
expectedDashboardPermissions: "full",
},
{
name: "Owner, limited view settings",
role: types.UserRoleOwner,
limitedViewSettings: true,
expectedDashboardPermissions: "full",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
}
users, err := am.ListUsers(context.Background(), mockAccountID)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.Equal(t, 1, len(users))
userInfo, _ := users[0].ToUserInfo(nil, account.Settings)
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
})
}
}
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
@@ -1573,154 +1654,121 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
tt := []struct {
name string
userAuth nbcontext.UserAuth
accountId string
userId string
expectedErr error
expectedResult *users.UserInfoWithPermissions
expectedResult *types.UserInfo
}{
{
name: "not found",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"},
accountId: account1.Id,
userId: "not-found",
expectedErr: status.NewUserNotFoundError("not-found"),
},
{
name: "not part of account",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
accountId: account1.Id,
userId: "account2Owner",
expectedErr: status.NewUserNotPartOfAccountError(),
},
{
name: "blocked",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
accountId: account1.Id,
userId: "blocked-user",
expectedErr: status.NewUserBlockedError(),
},
{
name: "service user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"},
accountId: account1.Id,
userId: "service-user",
expectedErr: status.NewPermissionDeniedError(),
},
{
name: "owner user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "account1Owner",
Name: "",
Role: "owner",
AutoGroups: []string{},
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
name: "owner user",
accountId: account1.Id,
userId: "account1Owner",
expectedResult: &types.UserInfo{
ID: "account1Owner",
Name: "",
Role: "owner",
AutoGroups: []string{},
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
Permissions: types.UserPermissions{
DashboardView: "full",
},
Permissions: mergeRolePermissions(roles.Owner),
},
},
{
name: "regular user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "regular-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
name: "regular user",
accountId: account1.Id,
userId: "regular-user",
expectedResult: &types.UserInfo{
ID: "regular-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
Permissions: types.UserPermissions{
DashboardView: "limited",
},
Permissions: mergeRolePermissions(roles.User),
},
},
{
name: "admin user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "admin-user",
Name: "",
Role: "admin",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
name: "admin user",
accountId: account1.Id,
userId: "admin-user",
expectedResult: &types.UserInfo{
ID: "admin-user",
Name: "",
Role: "admin",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
Permissions: types.UserPermissions{
DashboardView: "full",
},
Permissions: mergeRolePermissions(roles.Admin),
},
},
{
name: "settings blocked regular user",
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "settings-blocked-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
name: "settings blocked regular user",
accountId: account2.Id,
userId: "settings-blocked-user",
expectedResult: &types.UserInfo{
ID: "settings-blocked-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
Permissions: types.UserPermissions{
DashboardView: "blocked",
},
Permissions: mergeRolePermissions(roles.User),
Restricted: true,
},
},
{
name: "settings blocked regular user child account",
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "settings-blocked-user",
Name: "",
Role: "user",
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
},
Permissions: mergeRolePermissions(roles.User),
Restricted: false,
},
},
{
name: "settings blocked owner user",
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "account2Owner",
Name: "",
Role: "owner",
AutoGroups: []string{},
Status: "active",
IsServiceUser: false,
IsBlocked: false,
NonDeletable: false,
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
},
Permissions: mergeRolePermissions(roles.Owner),
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
result, err := am.GetCurrentUserInfo(context.Background(), tc.userAuth)
result, err := am.GetCurrentUserInfo(context.Background(), tc.accountId, tc.userId)
if tc.expectedErr != nil {
assert.Equal(t, err, tc.expectedErr)
@@ -1732,17 +1780,3 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
})
}
}
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
permissions := roles.Permissions{}
for k := range modules.All {
if rolePermissions, ok := role.Permissions[k]; ok {
permissions[k] = rolePermissions
continue
}
permissions[k] = role.AutoAllowNew
}
return permissions
}

View File

@@ -1,14 +0,0 @@
package users
import (
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types"
)
// Wrapped UserInfo with Role Permissions
type UserInfoWithPermissions struct {
*types.UserInfo
Permissions roles.Permissions
Restricted bool
}

View File

@@ -28,16 +28,6 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return nil, err
}
// Get the base TLS config
tlsClientConfig := quictls.ClientQUICTLSConfig()
// Set ServerName to hostname if not an IP address
host, _, splitErr := net.SplitHostPort(quicURL)
if splitErr == nil && net.ParseIP(host) == nil {
// It's a hostname, not an IP - modify directly
tlsClientConfig.ServerName = host
}
quicConfig := &quic.Config{
KeepAlivePeriod: 30 * time.Second,
MaxIdleTimeout: 4 * time.Minute,
@@ -57,7 +47,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return nil, err
}
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
@@ -71,29 +61,12 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string
switch {
case strings.HasPrefix(address, "rels://"):
host = address[7:]
defaultPort = "443"
case strings.HasPrefix(address, "rel://"):
host = address[6:]
defaultPort = "80"
default:
if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
finalHost, finalPort, err := net.SplitHostPort(host)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
return host + ":" + defaultPort, nil
}
// return any other split error as is
return "", err
if strings.HasPrefix(address, "rels://") {
return address[7:], nil
}
return finalHost + ":" + finalPort, nil
return address[6:], nil
}

View File

@@ -224,22 +224,16 @@ check_use_bin_variable() {
install_netbird() {
if [ -x "$(command -v netbird)" ]; then
status_output="$(netbird status 2>&1 || true)"
status_output=$(netbird status)
if echo "$status_output" | grep -q 'Management: Connected' && echo "$status_output" | grep -q 'Signal: Connected'; then
echo "NetBird service is running, please stop it before proceeding"
exit 1
fi
if echo "$status_output" | grep -q 'failed to connect to daemon error: context deadline exceeded'; then
echo "Warning: could not reach NetBird daemon (timeout), proceeding anyway"
else
if echo "$status_output" | grep -q 'Management: Connected' && \
echo "$status_output" | grep -q 'Signal: Connected'; then
echo "NetBird service is running, please stop it before proceeding"
exit 1
fi
if [ -n "$status_output" ]; then
echo "NetBird seems to be installed already, please remove it before proceeding"
exit 1
fi
fi
if [ -n "$status_output" ]; then
echo "NetBird seems to be installed already, please remove it before proceeding"
exit 1
fi
fi
# Run the installation, if a desktop environment is not detected

View File

@@ -7,8 +7,6 @@ const (
ClientHeaderValue = "netbird"
// GetURLPath is the path for the GetURL request
GetURLPath = "/upload-url"
DefaultBundleURL = "https://upload.debug.netbird.io" + GetURLPath
)
// GetURLResponse is the response for the GetURL request