Compare commits

...

9 Commits

Author SHA1 Message Date
Viktor Liu
ff5eddf70b Merge branch 'main' into add-ns-punnycode-support 2025-06-08 13:14:52 +02:00
Viktor Liu
273160c682 [client] Use punycode domains internally consequently (#3867) 2025-05-24 18:25:15 +02:00
bcmmbaga
1d6c360aec fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-23 13:07:26 +03:00
bcmmbaga
f04e7c3f06 Merge branch 'main' into add-ns-punnycode-support
# Conflicts:
#	management/server/nameserver.go
#	management/server/nameserver_test.go
2025-05-23 13:00:19 +03:00
bcmmbaga
3d89cd43c2 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:30 +03:00
bcmmbaga
0eeda712d0 add support for punycode domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:12 +03:00
bcmmbaga
3e3268db5f Remove support for wildcard ns match domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 19:01:53 +03:00
bcmmbaga
31f0879e71 remove the leading dot and root dot support ns regex
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 18:51:05 +03:00
bcmmbaga
f25b5bb987 Enhance match domain validation logic and add test cases
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 16:35:45 +03:00
38 changed files with 287 additions and 259 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
}
}
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
log.Info("create tun interface")
routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains)
searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {

View File

@@ -8,10 +8,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address

View File

@@ -2,7 +2,11 @@
package iface
import "fmt"
import (
"fmt"
"github.com/netbirdio/netbird/management/domain"
)
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on non mobile")
}

View File

@@ -2,11 +2,13 @@ package iface
import (
"fmt"
"github.com/netbirdio/netbird/management/domain"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
w.mu.Lock()
defer w.mu.Unlock()

View File

@@ -7,6 +7,8 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/management/domain"
)
// Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue
}
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
}
return listOfDomains
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
const (
@@ -24,8 +26,8 @@ type SubdomainMatcher interface {
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
Pattern domain.Domain
OrigPattern domain.Domain
IsWildcard bool
MatchSubdomains bool
}
@@ -39,7 +41,7 @@ type HandlerChain struct {
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
origPattern domain.Domain
shouldContinue bool
}
@@ -59,18 +61,18 @@ func NewHandlerChain() *HandlerChain {
}
// GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string {
func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
return w.origPattern
}
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
if isWildcard {
pattern = pattern[2:]
}
@@ -110,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
// domain specificity next
if h.Priority == newEntry.Priority {
newDots := strings.Count(newEntry.Pattern, ".")
existingDots := strings.Count(h.Pattern, ".")
newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
if newDots > existingDots {
return i
}
@@ -123,20 +125,20 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
}
// RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = dns.Fqdn(pattern)
pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
c.removeEntry(pattern, priority)
}
func (c *HandlerChain) removeEntry(pattern string, priority int) {
func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
@@ -201,16 +203,16 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
} else {
return strings.EqualFold(qname, entry.Pattern)
return strings.EqualFold(qname, entry.Pattern.PunycodeString())
}
}
}

View File

@@ -9,6 +9,7 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/management/domain"
)
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct {
name string
handlerDomain string
queryDomain string
handlerDomain domain.Domain
queryDomain domain.Domain
isWildcard bool
matchSubdomains bool
shouldMatch bool
@@ -141,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct {
name string
handlers []struct {
pattern string
pattern domain.Domain
priority int
}
queryDomain string
queryDomain domain.Domain
expectedCalls int
expectedHandler int // index of the handler that should be called
}{
{
name: "wildcard and exact same priority - exact should win",
handlers: []struct {
pattern string
pattern domain.Domain
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{
name: "higher priority wildcard over lower priority exact",
handlers: []struct {
pattern string
pattern domain.Domain
priority int
}{
{pattern: "example.com.", priority: nbdns.PriorityDefault},
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{
name: "multiple wildcards different priorities",
handlers: []struct {
pattern string
pattern domain.Domain
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{
name: "subdomain with mix of patterns",
handlers: []struct {
pattern string
pattern domain.Domain
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{
name: "root zone with specific domain",
handlers: []struct {
pattern string
pattern domain.Domain
priority int
}{
{pattern: ".", priority: nbdns.PriorityDefault},
@@ -258,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
@@ -330,7 +331,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name string
ops []struct {
action string // "add" or "remove"
pattern string
pattern domain.Domain
priority int
}
query string
@@ -340,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove high priority keeps lower priority handler",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -357,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove lower priority keeps high priority handler",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -374,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove all handlers in order",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -436,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain()
testDomain := "example.com."
testDomain := domain.Domain("example.com.")
testQuery := "test.example.com."
// Create handlers with MatchSubdomains enabled
@@ -518,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name string
scenario string
addHandlers []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -530,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -544,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -558,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -573,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -587,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -601,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
pattern domain.Domain
priority int
subdomains bool
shouldMatch bool
@@ -618,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called
handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
// Add handlers according to test case
for _, h := range tt.addHandlers {
@@ -686,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario string
ops []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}
query string
expectedMatch string
query domain.Domain
expectedMatch domain.Domain
}{
{
name: "more specific domain matches first",
scenario: "sub.example.com should match before example.com",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}{
@@ -713,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "sub.example.com should match before example.com",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}{
@@ -728,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "after removing most specific, should fall back to less specific",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}{
@@ -745,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "less specific domain with higher priority should match first",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}{
@@ -760,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "with equal priority, more specific domain should match",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}{
@@ -776,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "specific domain should match before wildcard at same priority",
ops: []struct {
action string
pattern string
pattern domain.Domain
priority int
subdomain bool
}{
@@ -791,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[string]*nbdns.MockSubdomainHandler)
handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
for _, op := range tt.ops {
if op.action == "add" {
@@ -804,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
}
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations
@@ -836,9 +837,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct {
name string
addPattern string
removePattern string
queryPattern string
addPattern domain.Domain
removePattern domain.Domain
queryPattern domain.Domain
shouldBeRemoved bool
description string
}{
@@ -954,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
)
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
@@ -39,9 +40,9 @@ type HostDNSConfig struct {
}
type DomainConfig struct {
Disabled bool `json:"disabled"`
Domain string `json:"domain"`
MatchOnly bool `json:"matchOnly"`
Disabled bool `json:"disabled"`
Domain domain.Domain `json:"domain"`
MatchOnly bool `json:"matchOnly"`
}
type mockHostConfigurator struct {
@@ -103,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
config.RouteAll = true
}
for _, domain := range nsConfig.Domains {
for _, d := range nsConfig.Domains {
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(domain)),
Domain: domain.Domain(d),
MatchOnly: !nsConfig.SearchDomainsEnabled,
})
}
}
for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
d := strings.ToLower(dns.Fqdn(customZone.Domain))
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
Domain: domain.Domain(d),
MatchOnly: matchOnly,
})
}

View File

@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue
}
if dConf.MatchOnly {
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
continue
}
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

View File

@@ -186,9 +186,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
continue
}
if !dConf.MatchOnly {
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
}
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
}
if len(matchDomains) != 0 {

View File

@@ -62,8 +62,8 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
return fmt.Errorf("method UpdateDNSServer is not implemented")
}
func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
func (m *MockServer) SearchDomains() domain.List {
return make(domain.List, 0)
}
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface

View File

@@ -125,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
continue
}
if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dConf.Domain)
matchDomains = append(matchDomains, "~."+dConf.Domain.PunycodeString())
continue
}
searchDomains = append(searchDomains, dConf.Domain)
searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
}
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic

View File

@@ -1,21 +1,19 @@
package dns
import (
"reflect"
"sort"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/management/domain"
)
type notifier struct {
listener listener.NetworkChangeListener
listenerMux sync.Mutex
searchDomains []string
searchDomains domain.List
}
func newNotifier(initialSearchDomains []string) *notifier {
sort.Strings(initialSearchDomains)
func newNotifier(initialSearchDomains domain.List) *notifier {
return &notifier{
searchDomains: initialSearchDomains,
}
@@ -27,16 +25,8 @@ func (n *notifier) setListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
func (n *notifier) onNewSearchDomains(searchDomains []string) {
sort.Strings(searchDomains)
if len(n.searchDomains) != len(searchDomains) {
n.searchDomains = searchDomains
n.notify()
return
}
if reflect.DeepEqual(n.searchDomains, searchDomains) {
func (n *notifier) onNewSearchDomains(searchDomains domain.List) {
if searchDomains.Equal(n.searchDomains) {
return
}

View File

@@ -44,12 +44,12 @@ type Server interface {
DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
SearchDomains() domain.List
ProbeAvailability()
}
type nsGroupsByDomain struct {
domain string
domain domain.Domain
groups []*nbdns.NameServerGroup
}
@@ -90,7 +90,7 @@ type handlerWithStop interface {
}
type handlerWrapper struct {
domain string
domain domain.Domain
handler handlerWithStop
priority int
}
@@ -197,7 +197,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
s.mux.Lock()
defer s.mux.Unlock()
s.registerHandler(domains.ToPunycodeList(), handler, priority)
s.registerHandler(domains, handler, priority)
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
for _, domain := range domains {
@@ -207,7 +207,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
s.applyHostConfig()
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
func (s *DefaultServer) registerHandler(domains domain.List, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains {
@@ -224,7 +224,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.deregisterHandler(domains.ToPunycodeList(), priority)
s.deregisterHandler(domains, priority)
for _, domain := range domains {
zone := toZone(domain)
s.extraDomains[zone]--
@@ -235,7 +235,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.applyHostConfig()
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
func (s *DefaultServer) deregisterHandler(domains domain.List, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority)
for _, domain := range domains {
@@ -378,8 +378,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
return nil
}
func (s *DefaultServer) SearchDomains() []string {
var searchDomains []string
func (s *DefaultServer) SearchDomains() domain.List {
var searchDomains domain.List
for _, dConf := range s.currentConfig.Domains {
if dConf.Disabled {
@@ -472,18 +472,16 @@ func (s *DefaultServer) applyHostConfig() {
config := s.currentConfig
existingDomains := make(map[string]struct{})
existingDomains := make(map[domain.Domain]struct{})
for _, d := range config.Domains {
existingDomains[d.Domain] = struct{}{}
}
// add extra domains only if they're not already in the config
for domain := range s.extraDomains {
domainStr := domain.PunycodeString()
if _, exists := existingDomains[domainStr]; !exists {
for d := range s.extraDomains {
if _, exists := existingDomains[d]; !exists {
config.Domains = append(config.Domains, DomainConfig{
Domain: domainStr,
Domain: d,
MatchOnly: true,
})
}
@@ -525,7 +523,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
}
muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain,
domain: domain.Domain(customZone.Domain),
handler: s.localResolver,
priority: PriorityMatchDomain,
})
@@ -647,7 +645,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority)
s.deregisterHandler(domain.List{existing.domain}, existing.priority)
existing.handler.Stop()
}
@@ -658,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
if update.domain == nbdns.RootZone {
containsRootUpdate = true
}
s.registerHandler([]string{update.domain}, update.handler, update.priority)
s.registerHandler(domain.List{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.ID()] = update
}
@@ -687,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
handler dns.Handler,
priority int,
) (deactivate func(error), reactivate func()) {
var removeIndex map[string]int
var removeIndex map[domain.Domain]int
deactivate = func(err error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -695,20 +693,20 @@ func (s *DefaultServer) upstreamCallbacks(
l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("Temporarily deactivating nameservers group due to timeout")
removeIndex = make(map[string]int)
removeIndex = make(map[domain.Domain]int)
for _, domain := range nsGroup.Domains {
removeIndex[domain] = -1
}
if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false
s.deregisterHandler([]string{nbdns.RootZone}, priority)
s.deregisterHandler(domain.List{nbdns.RootZone}, priority)
}
for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true
s.deregisterHandler([]string{item.Domain}, priority)
s.deregisterHandler(domain.List{item.Domain}, priority)
removeIndex[item.Domain] = i
}
}
@@ -732,12 +730,12 @@ func (s *DefaultServer) upstreamCallbacks(
s.mux.Lock()
defer s.mux.Unlock()
for domain, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
for d, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != d{
continue
}
s.currentConfig.Domains[i].Disabled = false
s.registerHandler([]string{domain}, handler, priority)
s.registerHandler(domain.List{d}, handler, priority)
}
l := log.WithField("nameservers", nsGroup.NameServers)
@@ -745,7 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
s.currentConfig.RouteAll = true
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
s.registerHandler(domain.List{nbdns.RootZone}, handler, priority)
}
s.applyHostConfig()
@@ -777,7 +775,7 @@ func (s *DefaultServer) addHostRootZone() {
handler.deactivate = func(error) {}
handler.reactivate = func() {}
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
s.registerHandler(domain.List{nbdns.RootZone}, handler, PriorityDefault)
}
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
@@ -792,7 +790,7 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
state := peer.NSGroupState{
ID: generateGroupKey(group),
Servers: servers,
Domains: group.Domains,
Domains: group.Domains.ToPunycodeList(),
// The probe will determine the state, default enabled
Enabled: true,
Error: nil,
@@ -825,7 +823,7 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
// groupNSGroupsByDomain groups nameserver groups by their match domains
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
domainMap := make(map[string][]*nbdns.NameServerGroup)
domainMap := make(map[domain.Domain][]*nbdns.NameServerGroup)
for _, group := range nsGroups {
if group.Primary {

View File

@@ -6,7 +6,6 @@ import (
"net"
"net/netip"
"os"
"strings"
"testing"
"time"
@@ -96,7 +95,7 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
func generateDummyHandler(domain domain.Domain, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string
for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv))
@@ -151,7 +150,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
Domains: domain.List{"netbird.io"},
NameServers: nameServers,
},
{
@@ -183,7 +182,7 @@ func TestUpdateDNSServer(t *testing.T) {
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
@@ -201,7 +200,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
Domains: domain.List{"netbird.io"},
NameServers: nameServers,
},
},
@@ -302,8 +301,8 @@ func TestUpdateDNSServer(t *testing.T) {
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: domain.Domain(zoneRecords[0].Name),
handler: dummyHandler,
priority: PriorityMatchDomain,
},
@@ -318,8 +317,8 @@ func TestUpdateDNSServer(t *testing.T) {
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{
domain: domain.Domain(zoneRecords[0].Name),
handler: dummyHandler,
priority: PriorityMatchDomain,
},
@@ -493,7 +492,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
domain: domain.Domain(zoneRecords[0].Name),
handler: &local.Resolver{},
priority: PriorityMatchDomain,
},
@@ -525,7 +524,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
Domains: domain.List{"netbird.io"},
NameServers: nameServers,
},
{
@@ -591,7 +590,7 @@ func TestDNSServerStartStop(t *testing.T) {
t.Error(err)
}
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
dnsServer.registerHandler(domain.List{"netbird.cloud"}, dnsServer.localResolver, 1)
resolver := &net.Resolver{
PreferGo: true,
@@ -651,48 +650,48 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
domains := domain.List{}
for _, item := range config.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
domainsUpdate = strings.Join(domains, ",")
domainsUpdate = domains.PunycodeString()
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
Domains: domain.List{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil, 0)
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
expected := "domain0, domain2"
domains := domain.List{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got := strings.Join(domains, ",")
got := domains.PunycodeString()
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
expected = "domain0, domain1, domain2"
domains = domain.List{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got = strings.Join(domains, ",")
got = domains.PunycodeString()
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
@@ -860,7 +859,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53,
},
},
Domains: []string{"google.com"},
Domains: domain.List{"google.com"},
Primary: false,
},
},
@@ -1115,7 +1114,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
name string
initialHandlers registeredHandlerMap
updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain
expectedHandlers map[string]domain.Domain // map[HandlerID]domain
description string
}{
{
@@ -1131,7 +1130,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-group2": "example.com",
},
description: "When group1 is not included in the update, it should be removed while group2 remains",
@@ -1149,7 +1148,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com",
},
description: "When group2 is not included in the update, it should be removed while group1 remains",
@@ -1182,7 +1181,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-group3": "example.com",
@@ -1217,7 +1216,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain - 2,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-group3": "example.com",
@@ -1237,7 +1236,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 1,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-root2": ".",
},
description: "When root1 is not included in the update, it should be removed while root2 remains",
@@ -1254,7 +1253,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".",
},
description: "When root2 is not included in the update, it should be removed while root1 remains",
@@ -1285,7 +1284,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 1,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".",
"upstream-root2": ".",
"upstream-root3": ".",
@@ -1318,7 +1317,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityDefault - 2,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-root1": ".",
"upstream-root2": ".",
"upstream-root3": ".",
@@ -1345,7 +1344,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com",
"upstream-other": "other.com",
},
@@ -1384,7 +1383,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
expectedHandlers: map[string]domain.Domain{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-other": "other.com",
@@ -1440,7 +1439,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
chainEntry.Pattern.PunycodeString() == dns.Fqdn(muxEntry.domain.PunycodeString()) {
foundInMux = true
break
}
@@ -1459,8 +1458,8 @@ func TestExtraDomains(t *testing.T) {
registerDomains []domain.List
deregisterDomains []domain.List
finalConfig nbdns.Config
expectedDomains []string
expectedMatchOnly []string
expectedDomains domain.List
expectedMatchOnly domain.List
applyHostConfigCall int
}{
{
@@ -1474,12 +1473,12 @@ func TestExtraDomains(t *testing.T) {
{Domain: "config.example.com"},
},
},
expectedDomains: []string{
expectedDomains: domain.List{
"config.example.com.",
"extra1.example.com.",
"extra2.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra1.example.com.",
"extra2.example.com.",
},
@@ -1496,12 +1495,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
},
expectedDomains: []string{
expectedDomains: domain.List{
"config.example.com.",
"extra1.example.com.",
"extra2.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra1.example.com.",
"extra2.example.com.",
},
@@ -1519,12 +1518,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{
{"extra.example.com", "overlap.example.com"},
},
expectedDomains: []string{
expectedDomains: domain.List{
"config.example.com.",
"overlap.example.com.",
"extra.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra.example.com.",
},
applyHostConfigCall: 2,
@@ -1544,12 +1543,12 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{
{"extra1.example.com", "extra3.example.com"},
},
expectedDomains: []string{
expectedDomains: domain.List{
"config.example.com.",
"extra2.example.com.",
"extra4.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra2.example.com.",
"extra4.example.com.",
},
@@ -1570,13 +1569,13 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{
{"duplicate.example.com"},
},
expectedDomains: []string{
expectedDomains: domain.List{
"config.example.com.",
"extra.example.com.",
"other.example.com.",
"duplicate.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra.example.com.",
"other.example.com.",
"duplicate.example.com.",
@@ -1601,13 +1600,13 @@ func TestExtraDomains(t *testing.T) {
{Domain: "newconfig.example.com"},
},
},
expectedDomains: []string{
expectedDomains: domain.List{
"config.example.com.",
"newconfig.example.com.",
"extra.example.com.",
"duplicate.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra.example.com.",
"duplicate.example.com.",
},
@@ -1628,12 +1627,12 @@ func TestExtraDomains(t *testing.T) {
deregisterDomains: []domain.List{
{"protected.example.com"},
},
expectedDomains: []string{
expectedDomains: domain.List{
"extra.example.com.",
"config.example.com.",
"protected.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"extra.example.com.",
},
applyHostConfigCall: 3,
@@ -1644,7 +1643,7 @@ func TestExtraDomains(t *testing.T) {
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
Domains: domain.List{"ns.example.com", "overlap.ns.example.com"},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
@@ -1658,12 +1657,12 @@ func TestExtraDomains(t *testing.T) {
registerDomains: []domain.List{
{"extra.example.com", "overlap.ns.example.com"},
},
expectedDomains: []string{
expectedDomains: domain.List{
"ns.example.com.",
"overlap.ns.example.com.",
"extra.example.com.",
},
expectedMatchOnly: []string{
expectedMatchOnly: domain.List{
"ns.example.com.",
"overlap.ns.example.com.",
"extra.example.com.",
@@ -1734,8 +1733,8 @@ func TestExtraDomains(t *testing.T) {
lastConfig := capturedConfigs[len(capturedConfigs)-1]
// Check all expected domains are present
domainMap := make(map[string]bool)
matchOnlyMap := make(map[string]bool)
domainMap := make(map[domain.Domain]bool)
matchOnlyMap := make(map[domain.Domain]bool)
for _, d := range lastConfig.Domains {
domainMap[d.Domain] = true
@@ -1852,12 +1851,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
err := server.applyConfiguration(initialConfig)
assert.NoError(t, err)
var domains []string
var domains domain.List
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
}
assert.Contains(t, domains, "config.example.com.")
assert.Contains(t, domains, "extra.example.com.")
assert.Contains(t, domains, domain.Domain("config.example.com."))
assert.Contains(t, domains, domain.Domain("extra.example.com."))
// Now apply a new configuration with overlapping domain
updatedConfig := nbdns.Config{
@@ -1871,7 +1870,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
assert.NoError(t, err)
// Verify both domains are in config, but no duplicates
domains = []string{}
domains = domain.List{}
matchOnlyCount := 0
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
@@ -1880,12 +1879,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
}
}
assert.Contains(t, domains, "config.example.com.")
assert.Contains(t, domains, "extra.example.com.")
assert.Contains(t, domains, domain.Domain("config.example.com."))
assert.Contains(t, domains, domain.Domain("extra.example.com."))
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
// Extra domain should no longer be marked as match-only when in config
matchOnlyDomain := ""
var matchOnlyDomain domain.Domain
for _, d := range capturedConfig.Domains {
if d.Domain == "extra.example.com." && d.MatchOnly {
matchOnlyDomain = d.Domain
@@ -1938,10 +1937,10 @@ func TestDomainCaseHandling(t *testing.T) {
err := server.applyConfiguration(config)
assert.NoError(t, err)
var domains []string
var domains domain.List
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
}
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
assert.Contains(t, domains, domain.Domain("config.example.com."), "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, domain.Domain("mixed.example.com."), "Mixed case domain should be normalized and present")
}

View File

@@ -117,15 +117,15 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
continue
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dConf.Domain,
Domain: dConf.Domain.PunycodeString(),
MatchOnly: dConf.MatchOnly,
})
if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain)
matchDomains = append(matchDomains, dConf.Domain.PunycodeString())
continue
}
searchDomains = append(searchDomains, dConf.Domain)
searchDomains = append(searchDomains, dConf.Domain.PunycodeString())
}
if config.RouteAll {

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
)
const (
@@ -48,7 +49,7 @@ type upstreamResolverBase struct {
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []string
domain string
domain domain.Domain
disabled bool
failsCount atomic.Int32
successCount atomic.Int32
@@ -62,7 +63,7 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{

View File

@@ -10,6 +10,7 @@ import (
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -28,7 +29,7 @@ func newUpstreamResolver(
_ netip.Prefix,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,
domain domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
c := &upstreamResolver{

View File

@@ -10,6 +10,7 @@ import (
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
)
type upstreamResolver struct {
@@ -23,7 +24,7 @@ func newUpstreamResolver(
_ netip.Prefix,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
domain domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
)
type upstreamResolverIOS struct {
@@ -31,7 +32,7 @@ func newUpstreamResolver(
net netip.Prefix,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
domain domain.Domain,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)

View File

@@ -1165,7 +1165,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
for _, nsGroup := range protoDNSConfig.GetNameServerGroups() {
dnsNSGroup := &nbdns.NameServerGroup{
Primary: nsGroup.GetPrimary(),
Domains: nsGroup.GetDomains(),
Domains: domain.FromPunycodeList(nsGroup.GetDomains()),
SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(),
}
for _, ns := range nsGroup.GetNameServers() {

View File

@@ -44,6 +44,7 @@ import (
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
@@ -77,7 +78,7 @@ var (
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
CreateOnAndroidFunc func(routeRange []string, ip string, domains domain.List) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() wgaddr.Address
@@ -111,7 +112,7 @@ func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains domain.List) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/management/domain"
)
type wgIfaceBase interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
CreateOnAndroid(routeRange []string, ip string, domains domain.List) error
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address

View File

@@ -229,15 +229,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
var origPattern domain.Domain
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
origPattern = writer.GetOrigPattern()
}
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern)
originalDomain := origPattern
if originalDomain == "" {
originalDomain = resolvedDomain
}

View File

@@ -6,6 +6,8 @@ import (
"net/url"
"strconv"
"strings"
"github.com/netbirdio/netbird/management/domain"
)
const (
@@ -64,7 +66,7 @@ type NameServerGroup struct {
// Primary indicates that the nameserver group is the primary resolver for any dns query
Primary bool
// Domains indicate the dns query domains to use with this nameserver group
Domains []string `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"`
// Enabled group status
Enabled bool
// SearchDomainsEnabled indicates whether to add match domains to search domains list or not
@@ -142,7 +144,7 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
Groups: make([]string, len(g.Groups)),
Enabled: g.Enabled,
Primary: g.Primary,
Domains: make([]string, len(g.Domains)),
Domains: make(domain.List, len(g.Domains)),
SearchDomainsEnabled: g.SearchDomainsEnabled,
}
@@ -188,7 +190,7 @@ func containsNameServer(element NameServer, list []NameServer) bool {
return false
}
func compareGroupsList(list, other []string) bool {
func compareGroupsList[T comparable](list, other []T) bool {
if len(list) != len(other) {
return false
}

View File

@@ -30,7 +30,7 @@ func (d Domain) SafeString() string {
}
// PunycodeString returns the punycode representation of the Domain.
// This should only be used if a punycode domain is expected but only a string is supported.
// This should only be used if a punycode domain is expected but only a string is supported (e.g. an external library).
func (d Domain) PunycodeString() string {
return string(d)
}

View File

@@ -1,7 +1,7 @@
package domain
import (
"sort"
"slices"
"strings"
)
@@ -41,6 +41,7 @@ func (d List) ToSafeStringList() []string {
}
// String converts List to a comma-separated string.
// This is useful for displaying domain names in a user-friendly format.
func (d List) String() (string, error) {
list, err := d.ToStringList()
if err != nil {
@@ -50,7 +51,8 @@ func (d List) String() (string, error) {
}
// SafeString converts List to a comma-separated non-punycode string.
// If a domain cannot be converted, the original string is used.
// This is useful for displaying domain names in a user-friendly format.
// If a domain cannot be converted, the original (punycode) string is used.
func (d List) SafeString() string {
str, err := d.String()
if err != nil {
@@ -64,28 +66,22 @@ func (d List) PunycodeString() string {
return strings.Join(d.ToPunycodeList(), ", ")
}
// Equal checks if two domain lists are equal without considering the order.
func (d List) Equal(domains List) bool {
if len(d) != len(domains) {
return false
}
sort.Slice(d, func(i, j int) bool {
return d[i] < d[j]
})
d1 := slices.Clone(d)
d2 := slices.Clone(domains)
sort.Slice(domains, func(i, j int) bool {
return domains[i] < domains[j]
})
slices.Sort(d1)
slices.Sort(d2)
for i, domain := range d {
if domain != domains[i] {
return false
}
}
return true
return slices.Equal(d1, d2)
}
// FromStringList creates a DomainList from a slice of string.
// FromStringList creates a List from a slice of strings.
func FromStringList(s []string) (List, error) {
var dl List
for _, domain := range s {

View File

@@ -78,7 +78,7 @@ type Manager interface {
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)

View File

@@ -19,6 +19,9 @@ import (
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/idp"
nbdns "github.com/netbirdio/netbird/dns"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -1688,7 +1691,7 @@ func TestAccount_Copy(t *testing.T) {
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": {
ID: "nsGroup1",
Domains: []string{},
Domains: domain.List{},
Groups: []string{},
NameServers: []nbdns.NameServer{},
},

View File

@@ -258,7 +258,7 @@ func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
Domains: nsGroup.Domains.ToPunycodeList(),
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -367,7 +368,7 @@ func generateTestData(size int) nbdns.Config {
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
Domains: domain.List{domain.Domain(fmt.Sprintf("domain%d.com", i))},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
@@ -547,7 +548,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Port: dns.DefaultDNSPort,
}},
[]string{"groupB"},
true, []string{}, true, userID, false,
true, domain.List{}, true, userID, false,
)
assert.NoError(t, err)
@@ -580,7 +581,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Port: dns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
true, domain.List{}, true, userID, false,
)
assert.NoError(t, err)

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -83,7 +84,13 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
domains, err := domain.FromStringList(req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains format"), w)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -123,12 +130,18 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
return
}
domains, err := domain.FromStringList(req.Domains)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains format"), w)
return
}
updatedNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: req.Name,
Description: req.Description,
Primary: req.Primary,
Domains: req.Domains,
Domains: domains,
NameServers: nsList,
Groups: req.Groups,
Enabled: req.Enabled,
@@ -227,7 +240,7 @@ func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.Namese
Name: serverNSGroup.Name,
Description: serverNSGroup.Description,
Primary: serverNSGroup.Primary,
Domains: serverNSGroup.Domains,
Domains: serverNSGroup.Domains.ToSafeStringList(),
Groups: serverNSGroup.Groups,
Nameservers: nsList,
Enabled: serverNSGroup.Enabled,

View File

@@ -10,17 +10,15 @@ import (
"net/netip"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
)
const (
@@ -47,7 +45,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
},
},
Groups: []string{"testing"},
Domains: []string{"domain"},
Domains: domain.List{"domain"},
Enabled: true,
}
@@ -60,7 +58,7 @@ func initNameserversTestData() *nameserversHandler {
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,

View File

@@ -77,7 +77,7 @@ type MockAccountManager struct {
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
@@ -567,7 +567,7 @@ func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID,
}
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -18,7 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
const domainPattern = `^(?i)[a-z0-9]+([\-]+[a-z0-9]+)*[*.a-z]{1,}$`
var invalidDomainName = errors.New("invalid domain name")
@@ -36,7 +37,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains domain.List, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -252,7 +253,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
func validateDomainInput(primary bool, domains domain.List, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
@@ -268,7 +269,7 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
}
for _, domain := range domains {
if err := validateDomain(domain); err != nil {
if err := validateDomain(domain.PunycodeString()); err != nil {
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -41,7 +42,7 @@ func TestCreateNameServerGroup(t *testing.T) {
groups []string
nameServers []nbdns.NameServer
primary bool
domains []string
domains domain.List
searchDomains bool
}
@@ -102,7 +103,7 @@ func TestCreateNameServerGroup(t *testing.T) {
description: "super",
groups: []string{group1ID},
primary: false,
domains: []string{validDomain},
domains: domain.List{validDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -123,7 +124,7 @@ func TestCreateNameServerGroup(t *testing.T) {
Name: "super",
Description: "super",
Primary: false,
Domains: []string{"example.com"},
Domains: domain.List{"example.com"},
Groups: []string{group1ID},
NameServers: []nbdns.NameServer{
{
@@ -360,7 +361,7 @@ func TestCreateNameServerGroup(t *testing.T) {
name: "super",
description: "super",
groups: []string{group1ID},
domains: []string{invalidDomain},
domains: domain.List{invalidDomain},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
@@ -447,8 +448,8 @@ func TestSaveNameServerGroup(t *testing.T) {
validGroups := []string{group2ID}
invalidGroups := []string{"nonExisting"}
disabledPrimary := false
validDomains := []string{validDomain}
invalidDomains := []string{invalidDomain}
validDomains := domain.List{validDomain}
invalidDomains := domain.List{invalidDomain}
validNameServerList := []nbdns.NameServer{
{
@@ -491,7 +492,7 @@ func TestSaveNameServerGroup(t *testing.T) {
newID *string
newName *string
newPrimary *bool
newDomains []string
newDomains domain.List
newNSList []nbdns.NameServer
newGroups []string
skipCopying bool
@@ -908,6 +909,11 @@ func TestValidateDomain(t *testing.T) {
domain: "example.",
errFunc: require.NoError,
},
{
name: "Valid domain name with double hyphen",
domain: "xn--bcher-kva.com",
errFunc: require.NoError,
},
{
name: "Invalid wildcard domain name",
domain: "*.example",
@@ -924,8 +930,8 @@ func TestValidateDomain(t *testing.T) {
errFunc: require.Error,
},
{
name: "Invalid domain name with double hyphen",
domain: "test--example.com",
name: "Invalid domain name with double dot",
domain: "example..com",
errFunc: require.Error,
},
{
@@ -1009,7 +1015,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
true, domain.List{}, true, userID, false,
)
assert.NoError(t, err)
@@ -1054,7 +1060,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupB"},
true, []string{}, true, userID, false,
true, domain.List{}, true, userID, false,
)
assert.NoError(t, err)

View File

@@ -1108,7 +1108,7 @@ func TestToSyncResponse(t *testing.T) {
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Domains: domain.List{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
@@ -1121,7 +1121,7 @@ func TestToSyncResponse(t *testing.T) {
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Domains: domain.List{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
@@ -1995,7 +1995,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Port: nbdns.DefaultDNSPort,
}},
[]string{"groupC"},
true, []string{}, true, userID, false,
true, domain.List{}, true, userID, false,
)
require.NoError(t, err)