Compare commits

..

3 Commits

Author SHA1 Message Date
Hakan Sariman
2db23a42dc Add DNS configuration snapshot and per-domain statistics tracking 2025-09-09 17:09:13 +07:00
Hakan Sariman
c2822eebb0 [client] Enhance logging for peer disconnection events 2025-09-09 15:02:16 +07:00
Hakan Sariman
5b246e0a08 debug dns 2025-09-09 14:48:39 +07:00
52 changed files with 857 additions and 985 deletions

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
raceFlag: ""
runs-on: ubuntu-22.04
steps:
- name: Install Go

View File

@@ -33,7 +33,6 @@ type ErrListener interface {
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -182,11 +181,6 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}

View File

@@ -388,12 +388,12 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
}
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 10, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 10, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
defer cancel()
return grpc.DialContext(

View File

@@ -230,7 +230,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{WaitForConnectingShift: true})
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}

View File

@@ -15,7 +15,6 @@ import (
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -45,7 +44,7 @@ type ICEBind struct {
RecvChan chan RecvMessage
transportNet transport.Net
filterFn udpmux.FilterFn
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
@@ -55,13 +54,13 @@ type ICEBind struct {
closed bool
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -116,7 +115,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
@@ -159,8 +158,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,

View File

@@ -1,4 +1,4 @@
package udpmux
package bind
import (
"fmt"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192
// SingleSocketUDPMux is an implementation of the interface
type SingleSocketUDPMux struct {
params Params
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
closedChan chan struct{}
closeOnce sync.Once
@@ -32,9 +32,6 @@ type SingleSocketUDPMux struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
@@ -49,8 +46,8 @@ type SingleSocketUDPMux struct {
const maxAddrSize = 512
// Params are parameters for UDPMux.
type Params struct {
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
@@ -150,19 +147,18 @@ func isZeros(ip net.IP) bool {
return true
}
// NewSingleSocketUDPMux creates an implementation of UDPMux
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = getLogger()
}
mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
@@ -175,15 +171,15 @@ func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
return mux
}
func (m *SingleSocketUDPMux) updateLocalAddresses() {
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
@@ -220,13 +216,13 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
m.mu.Unlock()
}
// LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
@@ -240,7 +236,7 @@ func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -264,14 +260,12 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
return conn, nil
}
c := m.createMuxedConn(ufrag, candidateID)
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
m.candidateConnMap[candidateID] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
@@ -282,7 +276,7 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
@@ -290,12 +284,10 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
m.mu.Unlock()
@@ -322,7 +314,7 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
}
// IsClosed returns true if the mux had been closed
func (m *SingleSocketUDPMux) IsClosed() bool {
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
@@ -332,7 +324,7 @@ func (m *SingleSocketUDPMux) IsClosed() bool {
}
// Close the mux, no further connections could be created
func (m *SingleSocketUDPMux) Close() error {
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
@@ -355,11 +347,11 @@ func (m *SingleSocketUDPMux) Close() error {
return err
}
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
@@ -376,109 +368,81 @@ func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr str
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
CandidateID: candidateID,
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// Try to route to specific candidate connection first
if conn := m.findCandidateConnection(msg); conn != nil {
return conn.writePacket(msg.Raw, remoteAddr)
}
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
// Add connections from address map
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
}
// Forward to all found connections
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
// findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
@@ -487,13 +451,6 @@ func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedCo
return
}
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct {
buf []byte
}

View File

@@ -1,12 +1,12 @@
//go:build !ios
package udpmux
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package udpmux
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*SingleSocketUDPMux
*UDPMuxDefault
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
udpMuxParams := Params{
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
}
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
}
return nil
}
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -1,4 +1,4 @@
package udpmux
package bind
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,12 +16,11 @@ import (
)
type udpMuxedConnParams struct {
Mux *SingleSocketUDPMux
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
CandidateID string
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -120,10 +119,6 @@ func (c *udpMuxedConn) Close() error {
return err
}
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -30,7 +29,7 @@ type WGTunDevice struct {
name string
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -89,7 +88,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
}
return t.configurer, nil
}
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -27,7 +26,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -72,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +28,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -84,7 +83,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -12,8 +12,8 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
filterFn udpmux.FilterFn
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil
}
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -106,14 +106,14 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := udpmux.UniversalUDPMuxParams{
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux

View File

@@ -10,7 +10,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -27,7 +26,7 @@ type TunNetstackDevice struct {
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
@@ -81,7 +80,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -26,7 +25,7 @@ type USPDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -75,7 +74,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -30,7 +29,7 @@ type TunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -105,7 +104,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn udpmux.FilterFn
FilterFn bind.FilterFn
DisableDNS bool
}
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()

View File

@@ -1,64 +0,0 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,7 +0,0 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -280,12 +280,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
close(runningChan)
runningChan = nil
select {
case runningChan <- struct{}{}:
default:
}
}
<-engineCtx.Done()

View File

@@ -315,6 +315,10 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add sync response: %w", err)
}
if err := g.addDNSConfig(); err != nil {
log.Errorf("failed to add DNS config to debug bundle: %v", err)
}
if err := g.addStateFile(); err != nil {
log.Errorf("failed to add state file to debug bundle: %v", err)
}
@@ -341,6 +345,50 @@ func (g *BundleGenerator) createArchive() error {
return nil
}
// addDNSConfig writes a dns_config.json snapshot with routed domains and NS group status
func (g *BundleGenerator) addDNSConfig() error {
type nsGroup struct {
ID string `json:"id"`
Servers []string `json:"servers"`
Domains []string `json:"domains"`
Enabled bool `json:"enabled"`
Error string `json:"error,omitempty"`
}
type dnsConfig struct {
Groups []nsGroup `json:"name_server_groups"`
}
if g.statusRecorder == nil {
return nil
}
states := g.statusRecorder.GetDNSStates()
cfg := dnsConfig{Groups: make([]nsGroup, 0, len(states))}
for _, st := range states {
var servers []string
for _, ap := range st.Servers {
servers = append(servers, ap.String())
}
var errStr string
if st.Error != nil {
errStr = st.Error.Error()
}
cfg.Groups = append(cfg.Groups, nsGroup{
ID: st.ID,
Servers: servers,
Domains: st.Domains,
Enabled: st.Enabled,
Error: errStr,
})
}
bs, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("marshal dns config: %w", err)
}
return g.addFileToZip(bytes.NewReader(bs), "dns_config.json")
}
func (g *BundleGenerator) addSystemInfo() {
if err := g.addRoutes(); err != nil {
log.Errorf("failed to add routes to debug bundle: %v", err)

View File

@@ -46,6 +46,18 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry
firewall firewaller
resolver resolver
// failure rate tracking for routed domains
failureMu sync.Mutex
failureCounts map[string]int
failureWindow time.Duration
lastLogPerHost map[string]time.Time
// per-domain rolling stats and windows
statsMu sync.Mutex
stats map[string]*domainStats
winSize time.Duration
slowT time.Duration
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -56,9 +68,25 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
failureCounts: make(map[string]int),
failureWindow: 10 * time.Second,
lastLogPerHost: make(map[string]time.Time),
stats: make(map[string]*domainStats),
winSize: 10 * time.Second,
slowT: 300 * time.Millisecond,
}
}
type domainStats struct {
total int
success int
timeouts int
notfound int
failures int // other failures (incl. SERVFAIL-like)
slow int
lastLog time.Time
}
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
@@ -163,12 +191,19 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
start := time.Now()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
elapsed := time.Since(start)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
// record error stats for routed domains
f.recordErrorStats(strings.TrimSuffix(domain, "."), err)
return nil
}
// record success timing
f.recordSuccessStats(strings.TrimSuffix(domain, "."), elapsed)
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
@@ -306,6 +341,91 @@ func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter,
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
// Track failure rate for routed domains only
if resID, _ := f.getMatchingEntries(strings.TrimSuffix(domain, ".")); resID != "" {
f.recordDomainFailure(strings.TrimSuffix(domain, "."))
}
}
// recordErrorStats updates per-domain counters and emits rate-limited logs
func (f *DNSForwarder) recordErrorStats(domain string, err error) {
domain = strings.ToLower(domain)
f.statsMu.Lock()
s := f.ensureStats(domain)
s.total++
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) {
if dnsErr.IsNotFound {
s.notfound++
} else if dnsErr.Timeout() {
s.timeouts++
} else {
s.failures++
}
} else {
s.failures++
}
f.maybeLogDomainStats(domain, s)
f.statsMu.Unlock()
}
// recordSuccessStats updates per-domain latency stats and slow counters, logs if needed (rate-limited)
func (f *DNSForwarder) recordSuccessStats(domain string, elapsed time.Duration) {
domain = strings.ToLower(domain)
f.statsMu.Lock()
s := f.ensureStats(domain)
s.total++
s.success++
if elapsed >= f.slowT {
s.slow++
}
f.maybeLogDomainStats(domain, s)
f.statsMu.Unlock()
}
func (f *DNSForwarder) ensureStats(domain string) *domainStats {
if ds, ok := f.stats[domain]; ok {
return ds
}
ds := &domainStats{}
f.stats[domain] = ds
return ds
}
// maybeLogDomainStats logs a compact summary per routed domain at most once per window
func (f *DNSForwarder) maybeLogDomainStats(domain string, s *domainStats) {
now := time.Now()
if !s.lastLog.IsZero() && now.Sub(s.lastLog) < f.winSize {
return
}
// check if routed (avoid logging for non-routed domains)
if resID, _ := f.getMatchingEntries(domain); resID == "" {
return
}
// only log if something noteworthy happened in the window
noteworthy := s.timeouts > 0 || s.notfound > 0 || s.failures > 0 || s.slow > 0
if !noteworthy {
s.lastLog = now
return
}
// warn on persistent problems, info otherwise
levelWarn := s.timeouts >= 3 || s.failures >= 3
if levelWarn {
log.Warnf("[d] DNS stats: domain=%s total=%d ok=%d timeout=%d nxdomain=%d fail=%d slow=%d(>=%s)",
domain, s.total, s.success, s.timeouts, s.notfound, s.failures, s.slow, f.slowT)
} else {
log.Infof("[d] DNS stats: domain=%s total=%d ok=%d timeout=%d nxdomain=%d fail=%d slow=%d(>=%s)",
domain, s.total, s.success, s.timeouts, s.notfound, s.failures, s.slow, f.slowT)
}
// reset counters for next window
*s = domainStats{lastLog: now}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
@@ -341,6 +461,27 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
}
}
// recordDomainFailure increments failure count for the domain and logs at info/warn with throttling.
func (f *DNSForwarder) recordDomainFailure(domain string) {
domain = strings.ToLower(domain)
f.failureMu.Lock()
defer f.failureMu.Unlock()
f.failureCounts[domain]++
count := f.failureCounts[domain]
now := time.Now()
last, ok := f.lastLogPerHost[domain]
if ok && now.Sub(last) < f.failureWindow {
return
}
f.lastLogPerHost[domain] = now
log.Warnf("[d] DNS failures observed for routed domain: domain=%s failures=%d/%s", domain, count, f.failureWindow)
}
// getMatchingEntries retrieves the resource IDs for a given domain.
// It returns the most specific match and all matching resource IDs.
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {

View File

@@ -29,9 +29,9 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -166,7 +166,7 @@ type Engine struct {
wgInterface WGIface
udpMux *udpmux.UniversalUDPMuxDefault
udpMux *bind.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -461,7 +461,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
@@ -1326,7 +1326,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},

View File

@@ -26,11 +26,10 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -85,7 +84,7 @@ type MockWGIface struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
@@ -135,7 +134,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
@@ -414,7 +413,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)

View File

@@ -9,9 +9,9 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
Name() string
Address() wgaddr.Address
ToInterface() *net.Interface
Up() (*udpmux.UniversalUDPMuxDefault, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10
@@ -201,6 +201,8 @@ type Status struct {
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
lazyConnectionEnabled bool
lastDisconnectLog map[string]time.Time
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
@@ -229,6 +231,7 @@ func NewRecorder(mgmAddress string) *Status {
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
lastDisconnectLog: make(map[string]time.Time),
}
}
@@ -487,6 +490,9 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState
// info log about disconnect with impacted routes (throttled)
d.logPeerDisconnectIfNeeded(receivedState.PubKey, peerState)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
@@ -519,6 +525,9 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
// info log about disconnect with impacted routes (throttled)
d.logPeerDisconnectIfNeeded(receivedState.PubKey, peerState)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
@@ -529,6 +538,49 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}
// logPeerDisconnectIfNeeded logs an info message when a routing peer transitions to disconnected
// with the number of impacted routes. Throttled to once per peer per 30 seconds.
func (d *Status) logPeerDisconnectIfNeeded(pubKey string, state State) {
if state.ConnStatus != StatusIdle {
return
}
now := time.Now()
last, ok := d.lastDisconnectLog[pubKey]
if ok && now.Sub(last) < 10*time.Second {
return
}
d.lastDisconnectLog[pubKey] = now
routes := state.GetRoutes()
numRoutes := len(routes)
fqdn := state.FQDN
if fqdn == "" {
fqdn = pubKey
}
// prepare a bounded list of impacted routes to avoid huge log lines
maxList := 20
list := make([]string, 0, maxList)
for r := range routes {
if len(list) >= maxList {
break
}
list = append(list, r)
}
more := ""
if numRoutes > len(list) {
more = ", more=" + fmt.Sprintf("%d", numRoutes-len(list))
}
if len(list) > 0 {
log.Warnf("[d] Routing peer disconnected: peer=%s impacted_routes=%d routes=%v%s", fqdn, numRoutes, list, more)
} else {
log.Warnf("[d] Routing peer disconnected: peer=%s impacted_routes=%d", fqdn, numRoutes)
}
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
d.mux.Lock()

View File

@@ -9,10 +9,11 @@ import (
"time"
"github.com/pion/ice/v4"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -54,6 +55,10 @@ type WorkerICE struct {
sessionID ICESessionID
muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string
localPwd string
@@ -134,6 +139,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.muxAgent.Unlock()
return
}
w.sentExtraSrflx = false
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
@@ -160,21 +166,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
w.log.Errorf("error while handling remote candidate")
return
}
if shouldAddExtraCandidate(candidate) {
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
@@ -218,9 +209,7 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
w.onICESelectedCandidatePair(agent, c1, c2)
}); err != nil {
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
return nil, err
}
@@ -338,7 +327,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
return
}
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
@@ -365,19 +354,48 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
pairStat, ok := agent.GetSelectedCandidatePairStats()
if !ok {
w.log.Warnf("failed to get selected candidate pair stats")
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
w.muxAgent.Lock()
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
w.log.Warnf("failed to get selected candidate pair: %s", err)
w.muxAgent.Unlock()
return
}
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.muxAgent.Unlock()
return
}
w.muxAgent.Unlock()
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
@@ -406,31 +424,22 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
if isController(w.config) {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
if candidate.Type() != ice.CandidateTypeServerReflexive {
return false
}
if candidate.Port() == candidate.RelatedAddress().Port {
return false
}
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
// in newer version we generate locally the extra candidate
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
return false
}
return true
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@@ -446,10 +455,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
}
for _, e := range candidate.Extensions() {
// overwrite the original candidate ID with the new one to avoid candidate duplication
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = candidate.ID()
}
if err := ec.AddExtension(e); err != nil {
return nil, err
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v3.21.9
// protoc v5.29.3
// source: daemon.proto
package proto
@@ -791,12 +791,11 @@ func (*UpResponse) Descriptor() ([]byte, []int) {
}
type StatusRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
WaitForConnectingShift bool `protobuf:"varint,3,opt,name=waitForConnectingShift,proto3" json:"waitForConnectingShift,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StatusRequest) Reset() {
@@ -843,13 +842,6 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
return false
}
func (x *StatusRequest) GetWaitForConnectingShift() bool {
if x != nil {
return x.WaitForConnectingShift
}
return false
}
type StatusResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// status of the server.
@@ -4681,11 +4673,10 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" +
"\t_username\"\f\n" +
"\n" +
"UpResponse\"\x9f\x01\n" +
"UpResponse\"g\n" +
"\rStatusRequest\x12,\n" +
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x126\n" +
"\x16waitForConnectingShift\x18\x03 \x01(\bR\x16waitForConnectingShift\"\x82\x01\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
"\x0eStatusResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
"\n" +

View File

@@ -186,7 +186,6 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
bool waitForConnectingShift = 3;
}
message StatusResponse{

View File

@@ -65,8 +65,6 @@ type Server struct {
mutex sync.Mutex
config *profilemanager.Config
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
connectClient *internal.ConnectClient
@@ -105,7 +103,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
func (s *Server) Start() error {
s.mutex.Lock()
defer s.mutex.Unlock()
state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil {
@@ -119,12 +116,14 @@ func (s *Server) Start() error {
// if current state contains any error, return it
// in all other cases we can continue execution only if status is idle and up command was
// not in the progress or already successfully established connection.
_, err := state.Status()
status, err := state.Status()
if err != nil {
return err
}
state.Set(internal.StatusConnecting)
if status != internal.StatusIdle {
return nil
}
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
@@ -173,12 +172,8 @@ func (s *Server) Start() error {
return nil
}
if s.clientRunning {
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
return nil
}
@@ -209,22 +204,12 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
s.mutex.Unlock()
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status,
runningChan chan struct{},
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
go func() {
t := time.NewTicker(24 * time.Hour)
for {
@@ -233,34 +218,91 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
t.Stop()
return
case <-t.C:
mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
backOff.Reset()
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
if retryStarted {
mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
retryStarted = false
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
}
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
err := s.connectClient.Run(runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
}
log.Tracef("client connection exited gracefully, do not need to retry")
return nil
if config.DisableAutoConnect {
return backoff.Permanent(err)
}
if !retryStarted {
retryStarted = true
backOff.Reset()
}
log.Tracef("client connection exited")
return fmt.Errorf("client connection exited")
}
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
err := backoff.Retry(runOperation, backOff)
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
log.Errorf("received an error when trying to connect: %v", err)
} else {
log.Tracef("retry canceled")
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType
@@ -674,14 +716,11 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
if !s.clientRunning {
s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
}
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
for {
select {
case <-s.clientRunningChan:
case <-runningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
@@ -959,33 +998,6 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
return mgmClient.Logout()
}
func waitStateShift(ctx context.Context) {
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
for {
select {
case <-ctx.Done():
log.Warnf("context done while waiting for state shift: %v", ctx.Err())
timer.Stop()
return
case <-timer.C:
log.Warnf("state shift timed out")
timer.Stop()
return
default:
status, err := internal.CtxGetState(ctx).Status()
if err != nil {
log.Errorf("failed to get status: %v", err)
return
}
if status != internal.StatusConnecting {
log.Infof("state shifting status: %v", status)
return
}
}
}
}
// Status returns the daemon status
func (s *Server) Status(
ctx context.Context,
@@ -998,10 +1010,6 @@ func (s *Server) Status(
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.WaitForConnectingShift {
waitStateShift(s.rootCtx)
}
status, err := internal.CtxGetState(s.rootCtx).Status()
if err != nil {
return nil, err
@@ -1119,134 +1127,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
}, nil
}
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
}
return nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
@@ -1258,45 +1138,6 @@ func (s *Server) onSessionExpire() {
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
@@ -1411,3 +1252,121 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}

View File

@@ -18,7 +18,6 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"

View File

@@ -529,7 +529,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
var req proto.SetConfigRequest
req.ProfileName = activeProf.Name
req.Username = currUser.Username
if iMngURL != "" {
req.ManagementUrl = iMngURL
}
@@ -563,28 +563,27 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
return
}
go func() {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
return
}
if status.Status == string(internal.StatusConnected) {
// run down & up
_, err = conn.Down(s.ctx, &proto.DownRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
log.Errorf("down service: %v", err)
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
log.Errorf("up service: %v", err)
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
return
}
if status.Status == string(internal.StatusConnected) {
// run down & up
_, err = conn.Down(s.ctx, &proto.DownRequest{})
if err != nil {
log.Errorf("down service: %v", err)
}
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
log.Errorf("up service: %v", err)
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
return
}
}
}()
}
},
OnCancel: func() {

2
go.mod
View File

@@ -261,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944

4
go.sum
View File

@@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=

View File

@@ -328,45 +328,6 @@ delete_auto_service_user() {
echo "$PARSED_RESPONSE"
}
delete_default_zitadel_admin() {
INSTANCE_URL=$1
PAT=$2
# Search for the default zitadel-admin user
RESPONSE=$(
curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
-d '{
"queries": [
{
"userNameQuery": {
"userName": "zitadel-admin@",
"method": "TEXT_QUERY_METHOD_STARTS_WITH"
}
}
]
}'
)
DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty')
if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then
echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID"
RESPONSE=$(
curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
)
PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"')
handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE"
else
echo "Default zitadel-admin user not found: $RESPONSE"
fi
}
init_zitadel() {
echo -e "\nInitializing Zitadel with NetBird's applications\n"
INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
@@ -385,9 +346,6 @@ init_zitadel() {
echo -n "Waiting for Zitadel to become ready "
wait_api "$INSTANCE_URL" "$PAT"
echo "Deleting default zitadel-admin user..."
delete_default_zitadel_admin "$INSTANCE_URL" "$PAT"
# create the zitadel project
echo "Creating new zitadel project"
PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT")

View File

@@ -167,22 +167,10 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return err
}
// TODO: Refactor to support multiple rules per policy
existingRuleIDs := make(map[string]bool)
for _, rule := range existingPolicy.Rules {
existingRuleIDs[rule.ID] = true
}
for _, rule := range policy.Rules {
if rule.ID != "" && !existingRuleIDs[rule.ID] {
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
}
}
} else {
policy.ID = xid.New().String()
policy.AccountID = accountID

View File

@@ -302,11 +302,7 @@ func (a *Account) GetPeerNetworkMap(
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
@@ -1655,24 +1651,3 @@ func peerSupportsPortRanges(peerVer string) bool {
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{})
// Add peer's own IP to include its own DNS records
peerIPs[peer.IP.String()] = struct{}{}
for _, peerToConnect := range peersToConnect {
peerIPs[peerToConnect.IP.String()] = struct{}{}
}
for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
}
}
return filteredRecords
}

View File

@@ -2,17 +2,14 @@ package types
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -838,109 +835,3 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct {
name string
peer *nbpeer.Peer
customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord
}{
{
name: "empty peers to connect",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
{
name: "multiple peers multiple records match",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for i := 1; i <= 100; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
})
}
return records
}(),
},
peersToConnect: func() []*nbpeer.Peer {
var peers []*nbpeer.Peer
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
peers = append(peers, &nbpeer.Peer{
ID: fmt.Sprintf("peer%d", i),
IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)),
})
}
return peers
}(),
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
})
}
return records
}(),
},
{
name: "peers with multiple DNS labels",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
{Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result)
})
}
}

View File

@@ -78,10 +78,9 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
tokenStore: tokenStore,
mtu: mtu,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),

View File

@@ -13,8 +13,11 @@ import (
)
const (
maxConcurrentServers = 7
defaultConnectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
var (
connectionTimeout = 30 * time.Second
)
type connResult struct {
@@ -24,15 +27,14 @@ type connResult struct {
}
type ServerPicker struct {
TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string
MTU uint16
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout)
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(sp.ServerURLs.Load().([]string))

View File

@@ -8,15 +8,15 @@ import (
)
func TestServerPicker_UnavailableServers(t *testing.T) {
timeout := 5 * time.Second
connectionTimeout = 5 * time.Second
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
ConnectionTimeout: timeout,
TokenStore: nil,
PeerID: "test",
}
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), timeout+1)
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel()
go func() {

View File

@@ -1,24 +0,0 @@
package healthcheck
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -1,36 +0,0 @@
package healthcheck
import (
"os"
"testing"
)
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -7,15 +7,10 @@ import (
log "github.com/sirupsen/logrus"
)
const (
defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second
var (
heartbeatTimeout = healthCheckInterval + 10*time.Second
)
type ReceiverOptions struct {
HeartbeatTimeout time.Duration
AttemptThreshold int
}
// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
@@ -32,23 +27,6 @@ type Receiver struct {
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver {
opts := ReceiverOptions{
HeartbeatTimeout: defaultHeartbeatTimeout,
AttemptThreshold: getAttemptThresholdFromEnv(),
}
return NewReceiverWithOpts(log, opts)
}
func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
heartbeatTimeout := opts.HeartbeatTimeout
if heartbeatTimeout <= 0 {
heartbeatTimeout = defaultHeartbeatTimeout
}
attemptThreshold := opts.AttemptThreshold
if attemptThreshold <= 0 {
attemptThreshold = defaultAttemptThreshold
}
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
@@ -57,10 +35,10 @@ func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: attemptThreshold,
attemptThreshold: getAttemptThresholdFromEnv(),
}
go r.waitForHealthcheck(heartbeatTimeout)
go r.waitForHealthcheck()
return r
}
@@ -77,7 +55,7 @@ func (r *Receiver) Stop() {
r.ctxCancel()
}
func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) {
func (r *Receiver) waitForHealthcheck() {
ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop()
defer r.ctxCancel()

View File

@@ -2,18 +2,31 @@ package healthcheck
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestNewReceiver(t *testing.T) {
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
opts := ReceiverOptions{
HeartbeatTimeout: 5 * time.Second,
}
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
@@ -25,10 +38,18 @@ func TestNewReceiver(t *testing.T) {
}
func TestNewReceiverNotReceive(t *testing.T) {
opts := ReceiverOptions{
HeartbeatTimeout: 1 * time.Second,
}
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
@@ -40,10 +61,18 @@ func TestNewReceiverNotReceive(t *testing.T) {
}
func TestNewReceiverAck(t *testing.T) {
opts := ReceiverOptions{
HeartbeatTimeout: 2 * time.Second,
}
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat()
@@ -68,19 +97,30 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
healthCheckInterval := 1 * time.Second
testMutex.Lock()
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
opts := ReceiverOptions{
HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond,
AttemptThreshold: tc.threshold,
}
defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
@@ -94,6 +134,7 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -2,76 +2,52 @@ package healthcheck
import (
"context"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultHealthCheckInterval = 25 * time.Second
defaultHealthCheckTimeout = 20 * time.Second
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
type SenderOptions struct {
HealthCheckInterval time.Duration
HealthCheckTimeout time.Duration
AttemptThreshold int
}
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 20 * time.Second
)
// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
log *log.Entry
healthCheckInterval time.Duration
timeout time.Duration
ack chan struct{}
alive bool
attemptThreshold int
}
func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender {
if opts.HealthCheckInterval <= 0 {
opts.HealthCheckInterval = defaultHealthCheckInterval
}
if opts.HealthCheckTimeout <= 0 {
opts.HealthCheckTimeout = defaultHealthCheckTimeout
}
if opts.AttemptThreshold <= 0 {
opts.AttemptThreshold = defaultAttemptThreshold
}
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
hc := &Sender{
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
log: log,
healthCheckInterval: opts.HealthCheckInterval,
timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout,
ack: make(chan struct{}, 1),
attemptThreshold: opts.AttemptThreshold,
log: log,
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
return hc
}
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
opts := SenderOptions{
HealthCheckInterval: defaultHealthCheckInterval,
HealthCheckTimeout: defaultHealthCheckTimeout,
AttemptThreshold: getAttemptThresholdFromEnv(),
}
return NewSenderWithOpts(log, opts)
}
// OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() {
select {
@@ -81,10 +57,10 @@ func (hc *Sender) OnHCResponse() {
}
func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(hc.healthCheckInterval)
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTicker := time.NewTicker(hc.timeout)
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
@@ -116,3 +92,19 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
}
}
}
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,23 +2,26 @@ package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
var (
testOpts = SenderOptions{
HealthCheckInterval: 2 * time.Second,
HealthCheckTimeout: 100 * time.Millisecond,
}
)
func TestMain(m *testing.M) {
// override the health check interval to speed up the test
healthCheckInterval = 2 * time.Second
healthCheckTimeout = 100 * time.Millisecond
code := m.Run()
os.Exit(code)
}
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -29,7 +32,7 @@ func TestNewHealthPeriod(t *testing.T) {
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
@@ -38,19 +41,19 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
select {
case <-hc.Timeout:
case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond):
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out")
}
}
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -86,7 +89,7 @@ func TestTimeoutReset(t *testing.T) {
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
@@ -115,16 +118,19 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
opts := SenderOptions{
HealthCheckInterval: 1 * time.Second,
HealthCheckTimeout: 500 * time.Millisecond,
AttemptThreshold: tc.threshold,
}
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts)
sender := NewSender(log.WithField("test_name", tc.name))
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
@@ -149,7 +155,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
}
}()
testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
@@ -169,7 +175,39 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/formatter"
)
const defaultLogSize = 15
const defaultLogSize = 100
const (
LogConsole = "console"