mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-04 15:02:36 -04:00
Compare commits
38 Commits
fix/up-seq
...
debug-and-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4b8ddac52 | ||
|
|
0ce32ec527 | ||
|
|
0a03145709 | ||
|
|
af18385291 | ||
|
|
b1f6e9b12d | ||
|
|
180b52ba26 | ||
|
|
f86a7f745e | ||
|
|
fd13247d66 | ||
|
|
1d83fccd9c | ||
|
|
7d3c972653 | ||
|
|
990aa8f7cd | ||
|
|
2e6a44af1f | ||
|
|
559d347588 | ||
|
|
06d71257b4 | ||
|
|
62d10496ee | ||
|
|
651e88d611 | ||
|
|
d496d21693 | ||
|
|
648b4cdf72 | ||
|
|
f0020ad4ce | ||
|
|
1eacff250e | ||
|
|
1f83ba4563 | ||
|
|
3d80a25b4d | ||
|
|
1963644c99 | ||
|
|
360c7134f7 | ||
|
|
775b4feb7e | ||
|
|
aca443bdec | ||
|
|
335866ac60 | ||
|
|
082452eb5f | ||
|
|
b17c1d96a5 | ||
|
|
2d5b5f59c2 | ||
|
|
d5042f688f | ||
|
|
4db73a13d7 | ||
|
|
06a17f0eee | ||
|
|
1f088b7e69 | ||
|
|
ffe74365a8 | ||
|
|
6a0f6efc18 | ||
|
|
bfa6df13c5 | ||
|
|
e9b3b6210d |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -217,7 +217,7 @@ jobs:
|
||||
- arch: "386"
|
||||
raceFlag: ""
|
||||
- arch: "amd64"
|
||||
raceFlag: "-race"
|
||||
raceFlag: ""
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
||||
@@ -4,7 +4,6 @@ package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
@@ -84,8 +83,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -120,8 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -252,14 +249,3 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
func exportEnvList(list *EnvList) {
|
||||
if list == nil {
|
||||
return
|
||||
}
|
||||
for k, v := range list.AllItems() {
|
||||
if err := os.Setenv(k, v); err != nil {
|
||||
log.Errorf("could not set env variable %s: %v", k, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package android
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
var (
|
||||
// EnvKeyNBForceRelay Exported for Android java client
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
)
|
||||
|
||||
// EnvList wraps a Go map for export to Java
|
||||
type EnvList struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
// NewEnvList creates a new EnvList
|
||||
func NewEnvList() *EnvList {
|
||||
return &EnvList{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
// Put adds a key-value pair
|
||||
func (el *EnvList) Put(key, value string) {
|
||||
el.data[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a value by key
|
||||
func (el *EnvList) Get(key string) string {
|
||||
return el.data[key]
|
||||
}
|
||||
|
||||
func (el *EnvList) AllItems() map[string]string {
|
||||
return el.data
|
||||
}
|
||||
@@ -33,7 +33,6 @@ type ErrListener interface {
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
// Auth can register or login new client
|
||||
@@ -182,11 +181,6 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
}
|
||||
|
||||
// update host's static platform and system information
|
||||
system.UpdateStaticInfoAsync()
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
configFilePath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
|
||||
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
||||
defer cancel()
|
||||
|
||||
return grpc.DialContext(
|
||||
|
||||
@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
log.Info("starting NetBird service") //nolint
|
||||
|
||||
// Collect static system and platform information
|
||||
system.UpdateStaticInfoAsync()
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||
p.serv = grpc.NewServer()
|
||||
|
||||
@@ -9,26 +9,29 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
sig "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
@@ -87,20 +90,15 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
settingsMockManager.EXPECT().
|
||||
|
||||
@@ -230,7 +230,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{WaitForConnectingShift: true})
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,17 @@
|
||||
package bind
|
||||
|
||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
import (
|
||||
"net"
|
||||
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type Endpoint = wgConn.StdNetEndpoint
|
||||
|
||||
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: e.Addr().AsSlice(),
|
||||
Port: int(e.Port()),
|
||||
Zone: e.Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
@@ -42,10 +42,10 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
|
||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
RecvChan chan RecvMessage
|
||||
recvChan chan RecvMessage
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
filterFn FilterFn
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||
@@ -55,17 +55,17 @@ type ICEBind struct {
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
activityRecorder *ActivityRecorder
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
recvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
@@ -116,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
if s.udpMux == nil {
|
||||
@@ -155,12 +155,20 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case b.recvChan <- msg:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(conn),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
@@ -271,7 +279,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-c.RecvChan:
|
||||
case msg, ok := <-c.recvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package udpmux
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -22,9 +22,9 @@ import (
|
||||
|
||||
const receiveMTU = 8192
|
||||
|
||||
// SingleSocketUDPMux is an implementation of the interface
|
||||
type SingleSocketUDPMux struct {
|
||||
params Params
|
||||
// UDPMuxDefault is an implementation of the interface
|
||||
type UDPMuxDefault struct {
|
||||
params UDPMuxParams
|
||||
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
@@ -32,9 +32,6 @@ type SingleSocketUDPMux struct {
|
||||
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
||||
|
||||
// candidateConnMap maps local candidate IDs to their corresponding connection.
|
||||
candidateConnMap map[string]*udpMuxedConn
|
||||
|
||||
addressMapMu sync.RWMutex
|
||||
addressMap map[string][]*udpMuxedConn
|
||||
|
||||
@@ -49,8 +46,8 @@ type SingleSocketUDPMux struct {
|
||||
|
||||
const maxAddrSize = 512
|
||||
|
||||
// Params are parameters for UDPMux.
|
||||
type Params struct {
|
||||
// UDPMuxParams are parameters for UDPMux.
|
||||
type UDPMuxParams struct {
|
||||
Logger logging.LeveledLogger
|
||||
UDPConn net.PacketConn
|
||||
|
||||
@@ -150,19 +147,18 @@ func isZeros(ip net.IP) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// NewSingleSocketUDPMux creates an implementation of UDPMux
|
||||
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
|
||||
mux := &SingleSocketUDPMux{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
connsIPv6: make(map[string]*udpMuxedConn),
|
||||
candidateConnMap: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
mux := &UDPMuxDefault{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
connsIPv6: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
pool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
// big enough buffer to fit both packet and address
|
||||
@@ -175,15 +171,15 @@ func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
|
||||
return mux
|
||||
}
|
||||
|
||||
func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
@@ -220,13 +216,13 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this SingleSocketUDPMux
|
||||
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
return m.params.UDPConn.LocalAddr()
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
m.updateLocalAddresses()
|
||||
|
||||
m.mu.Lock()
|
||||
@@ -240,7 +236,7 @@ func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
|
||||
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network address
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
// don't check addr for mux using unspecified address
|
||||
m.mu.Lock()
|
||||
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||
@@ -264,14 +260,12 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
c := m.createMuxedConn(ufrag, candidateID)
|
||||
c := m.createMuxedConn(ufrag)
|
||||
go func() {
|
||||
<-c.CloseChannel()
|
||||
m.RemoveConnByUfrag(ufrag)
|
||||
}()
|
||||
|
||||
m.candidateConnMap[candidateID] = c
|
||||
|
||||
if isIPv6 {
|
||||
m.connsIPv6[ufrag] = c
|
||||
} else {
|
||||
@@ -282,7 +276,7 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
|
||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||
|
||||
// Keep lock section small to avoid deadlock with conn lock
|
||||
@@ -290,12 +284,10 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
|
||||
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||
delete(m.connsIPv4, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
delete(m.candidateConnMap, c.GetCandidateID())
|
||||
}
|
||||
if c, ok := m.connsIPv6[ufrag]; ok {
|
||||
delete(m.connsIPv6, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
delete(m.candidateConnMap, c.GetCandidateID())
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
@@ -322,7 +314,7 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
|
||||
}
|
||||
|
||||
// IsClosed returns true if the mux had been closed
|
||||
func (m *SingleSocketUDPMux) IsClosed() bool {
|
||||
func (m *UDPMuxDefault) IsClosed() bool {
|
||||
select {
|
||||
case <-m.closedChan:
|
||||
return true
|
||||
@@ -332,7 +324,7 @@ func (m *SingleSocketUDPMux) IsClosed() bool {
|
||||
}
|
||||
|
||||
// Close the mux, no further connections could be created
|
||||
func (m *SingleSocketUDPMux) Close() error {
|
||||
func (m *UDPMuxDefault) Close() error {
|
||||
var err error
|
||||
m.closeOnce.Do(func() {
|
||||
m.mu.Lock()
|
||||
@@ -355,11 +347,11 @@ func (m *SingleSocketUDPMux) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
return m.params.UDPConn.WriteTo(buf, rAddr)
|
||||
}
|
||||
|
||||
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||
if m.IsClosed() {
|
||||
return
|
||||
}
|
||||
@@ -376,109 +368,81 @@ func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr str
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
|
||||
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
|
||||
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
||||
c := newUDPMuxedConn(&udpMuxedConnParams{
|
||||
Mux: m,
|
||||
Key: key,
|
||||
AddrPool: m.pool,
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
CandidateID: candidateID,
|
||||
Mux: m,
|
||||
Key: key,
|
||||
AddrPool: m.pool,
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
|
||||
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||
|
||||
remoteAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
|
||||
}
|
||||
|
||||
// Try to route to specific candidate connection first
|
||||
if conn := m.findCandidateConnection(msg); conn != nil {
|
||||
return conn.writePacket(msg.Raw, remoteAddr)
|
||||
}
|
||||
|
||||
// Fallback: route to all possible connections
|
||||
return m.forwardToAllConnections(msg, addr, remoteAddr)
|
||||
}
|
||||
|
||||
// findCandidateConnection attempts to find the specific connection for a STUN message
|
||||
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
|
||||
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
|
||||
if err != nil {
|
||||
return nil
|
||||
} else if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// forwardToAllConnections forwards STUN message to all relevant connections
|
||||
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
|
||||
var destinationConnList []*udpMuxedConn
|
||||
|
||||
// Add connections from address map
|
||||
// If we have already seen this address dispatch to the appropriate destination
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
m.addressMapMu.RLock()
|
||||
var destinationConnList []*udpMuxedConn
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
destinationConnList = append(destinationConnList, storedConns...)
|
||||
}
|
||||
m.addressMapMu.RUnlock()
|
||||
|
||||
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
|
||||
// If we have already seen this address dispatch to the appropriate destination
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
if !m.connectionExists(conn, destinationConnList) {
|
||||
destinationConnList = append(destinationConnList, conn)
|
||||
}
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
isIPv6 = true
|
||||
}
|
||||
|
||||
// Forward to all found connections
|
||||
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
|
||||
// However, we can take a username attribute from the STUN message which contains ufrag.
|
||||
// We can use ufrag to identify the destination conn to route packet to.
|
||||
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
||||
if stunAttrErr == nil {
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
|
||||
m.mu.Lock()
|
||||
destinationConn := m.connsIPv4[ufrag]
|
||||
if isIPv6 {
|
||||
destinationConn = m.connsIPv6[ufrag]
|
||||
}
|
||||
|
||||
if destinationConn != nil {
|
||||
exists := false
|
||||
for _, conn := range destinationConnList {
|
||||
if conn.params.Key == destinationConn.params.Key {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
destinationConnList = append(destinationConnList, destinationConn)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
|
||||
// It will be discarded by the further ICE candidate logic if so.
|
||||
for _, conn := range destinationConnList {
|
||||
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
|
||||
log.Errorf("could not write packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findConnectionByUsername finds connection using username attribute from STUN message
|
||||
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
|
||||
attr, err := msg.Get(stun.AttrUsername)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
isIPv6 := isIPv6Address(addr)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.getConn(ufrag, isIPv6)
|
||||
}
|
||||
|
||||
// connectionExists checks if a connection already exists in the list
|
||||
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
|
||||
for _, conn := range conns {
|
||||
if conn.params.Key == target.params.Key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||
if isIPv6 {
|
||||
val, ok = m.connsIPv6[ufrag]
|
||||
} else {
|
||||
@@ -487,13 +451,6 @@ func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedCo
|
||||
return
|
||||
}
|
||||
|
||||
func isIPv6Address(addr net.Addr) bool {
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||
return udpAddr.IP.To4() == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type bufferHolder struct {
|
||||
buf []byte
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
//go:build !ios
|
||||
|
||||
package udpmux
|
||||
package bind
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
|
||||
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
|
||||
conn.RemoveAddress(addr)
|
||||
7
client/iface/bind/udp_mux_ios.go
Normal file
7
client/iface/bind/udp_mux_ios.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build ios
|
||||
|
||||
package bind
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package udpmux
|
||||
package bind
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
|
||||
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
*SingleSocketUDPMux
|
||||
*UDPMuxDefault
|
||||
params UniversalUDPMuxParams
|
||||
|
||||
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
||||
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
udpMuxParams := Params{
|
||||
udpMuxParams := UDPMuxParams{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
Net: m.params.Net,
|
||||
}
|
||||
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
|
||||
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
|
||||
|
||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||
// and return a unique connection per server.
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
|
||||
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
|
||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
|
||||
}
|
||||
|
||||
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
|
||||
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
|
||||
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
|
||||
}
|
||||
|
||||
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
|
||||
@@ -1,4 +1,4 @@
|
||||
package udpmux
|
||||
package bind
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
||||
@@ -16,12 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type udpMuxedConnParams struct {
|
||||
Mux *SingleSocketUDPMux
|
||||
AddrPool *sync.Pool
|
||||
Key string
|
||||
LocalAddr net.Addr
|
||||
Logger logging.LeveledLogger
|
||||
CandidateID string
|
||||
Mux *UDPMuxDefault
|
||||
AddrPool *sync.Pool
|
||||
Key string
|
||||
LocalAddr net.Addr
|
||||
Logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||||
@@ -120,10 +119,6 @@ func (c *udpMuxedConn) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) GetCandidateID() string {
|
||||
return c.params.CandidateID
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) isClosed() bool {
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
@@ -7,14 +7,14 @@ import (
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
MTU() uint16
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -30,7 +29,7 @@ type WGTunDevice struct {
|
||||
name string
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -89,7 +88,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
}
|
||||
return t.configurer, nil
|
||||
}
|
||||
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -27,7 +26,7 @@ type TunDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -72,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ type TunDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -84,7 +83,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
|
||||
|
||||
link *wgLink
|
||||
udpMuxConn net.PacketConn
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
|
||||
filterFn udpmux.FilterFn
|
||||
filterFn bind.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
|
||||
return configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
if t.udpMux != nil {
|
||||
return t.udpMux, nil
|
||||
}
|
||||
@@ -106,14 +106,14 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
||||
}
|
||||
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: udpConn,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
MTU: t.mtu,
|
||||
}
|
||||
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
|
||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||
go mux.ReadFromConn(t.ctx)
|
||||
t.udpMuxConn = rawSock
|
||||
t.udpMux = mux
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
@@ -27,7 +26,7 @@ type TunNetstackDevice struct {
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
nsTun *nbnetstack.NetStackTun
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
|
||||
net *netstack.Net
|
||||
@@ -81,7 +80,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
if t.device == nil {
|
||||
return nil, fmt.Errorf("device is not ready yet")
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -26,7 +25,7 @@ type USPDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -75,7 +74,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
if t.device == nil {
|
||||
return nil, fmt.Errorf("device is not ready yet")
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -30,7 +29,7 @@ type TunDevice struct {
|
||||
device *device.Device
|
||||
nativeTunDevice *tun.NativeTun
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -105,7 +104,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
MTU() uint16
|
||||
|
||||
@@ -16,9 +16,9 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
|
||||
MTU uint16
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn udpmux.FilterFn
|
||||
FilterFn bind.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
|
||||
|
||||
// Up configures a Wireguard interface
|
||||
// The interface must exist before calling this method (e.g. call interface.Create() before)
|
||||
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
|
||||
40
client/iface/iface_new_freebsd.go
Normal file
40
client/iface/iface_new_freebsd.go
Normal file
@@ -0,0 +1,40 @@
|
||||
//go:build freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build linux && !android
|
||||
|
||||
package iface
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
// Package udpmux provides a custom implementation of a UDP multiplexer
|
||||
// that allows multiple logical ICE connections to share a single underlying
|
||||
// UDP socket. This is based on Pion's ICE library, with modifications for
|
||||
// NetBird's requirements.
|
||||
//
|
||||
// # Background
|
||||
//
|
||||
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
|
||||
// Establishment) is responsible for discovering candidate network paths
|
||||
// and maintaining connectivity between peers. Each ICE connection
|
||||
// normally requires a dedicated UDP socket. However, using one socket
|
||||
// per candidate can be inefficient and difficult to manage.
|
||||
//
|
||||
// This package introduces SingleSocketUDPMux, which allows multiple ICE
|
||||
// candidate connections (muxed connections) to share a single UDP socket.
|
||||
// It handles demultiplexing of packets based on ICE ufrag values, STUN
|
||||
// attributes, and candidate IDs.
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// The typical flow is:
|
||||
//
|
||||
// 1. Create a UDP socket (net.PacketConn).
|
||||
// 2. Construct Params with the socket and optional logger/net stack.
|
||||
// 3. Call NewSingleSocketUDPMux(params).
|
||||
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
|
||||
// to obtain a logical PacketConn.
|
||||
// 5. Use the returned PacketConn just like a normal UDP connection.
|
||||
//
|
||||
// # STUN Message Routing Logic
|
||||
//
|
||||
// When a STUN packet arrives, the mux decides which connection should
|
||||
// receive it using this routing logic:
|
||||
//
|
||||
// Primary Routing: Candidate Pair ID
|
||||
// - Extract the candidate pair ID from the STUN message using
|
||||
// ice.CandidatePairIDFromSTUN(msg)
|
||||
// - The target candidate is the locally generated candidate that
|
||||
// corresponds to the connection that should handle this STUN message
|
||||
// - If found, use the target candidate ID to lookup the specific
|
||||
// connection in candidateConnMap
|
||||
// - Route the message directly to that connection
|
||||
//
|
||||
// Fallback Routing: Broadcasting
|
||||
// When candidate pair ID is not available or lookup fails:
|
||||
// - Collect connections from addressMap based on source address
|
||||
// - Find connection using username attribute (ufrag) from STUN message
|
||||
// - Remove duplicate connections from the list
|
||||
// - Send the STUN message to all collected connections
|
||||
//
|
||||
// # Peer Reflexive Candidate Discovery
|
||||
//
|
||||
// When a remote peer sends a STUN message from an unknown source address
|
||||
// (from a candidate that has not been exchanged via signal), the ICE
|
||||
// library will:
|
||||
// - Generate a new peer reflexive candidate for this source address
|
||||
// - Extract or assign a candidate ID based on the STUN message attributes
|
||||
// - Create a mapping between the new peer reflexive candidate ID and
|
||||
// the appropriate local connection
|
||||
//
|
||||
// This discovery mechanism ensures that STUN messages from newly discovered
|
||||
// peer reflexive candidates can be properly routed to the correct local
|
||||
// connection without requiring fallback broadcasting.
|
||||
package udpmux
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package udpmux
|
||||
|
||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -16,28 +16,37 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
type IceBind interface {
|
||||
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
|
||||
RemoveEndpoint(fakeIP netip.Addr)
|
||||
Recv(ctx context.Context, msg bind.RecvMessage)
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
bind IceBind
|
||||
|
||||
fakeNetIP *netip.AddrPort
|
||||
wgBindEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
|
||||
wgRelayedEndpoint *bind.Endpoint
|
||||
wgCurrentUsed *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||
func NewProxyBind(bind IceBind) *ProxyBind {
|
||||
p := &ProxyBind{
|
||||
Bind: bind,
|
||||
bind: bind,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
return p
|
||||
@@ -46,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||
// AddTurnConn adds a new connection to the bind.
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
|
||||
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
|
||||
// - remoteConn: The established TURN connection to the remote peer
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
fakeNetIP, err := fakeAddress(nbAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.fakeNetIP = fakeNetIP
|
||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||
Port: int(p.fakeNetIP.Port()),
|
||||
Zone: p.fakeNetIP.Addr().Zone(),
|
||||
}
|
||||
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
|
||||
}
|
||||
|
||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||
@@ -76,17 +85,22 @@ func (p *ProxyBind) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
p.wgCurrentUsed = p.wgRelayedEndpoint
|
||||
|
||||
// Start the proxy only once
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
|
||||
p.pausedCond.L.Unlock()
|
||||
// todo: review to should be inside the lock scope
|
||||
p.pausedCond.Signal()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Pause() {
|
||||
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
if p.remoteConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
|
||||
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}()
|
||||
|
||||
for {
|
||||
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
|
||||
buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
@@ -147,18 +186,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgBindEndpoint,
|
||||
Endpoint: p.wgCurrentUsed,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
p.pausedMu.Unlock()
|
||||
p.bind.Recv(ctx, msg)
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -18,6 +16,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -27,6 +26,10 @@ const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = p.prepareSenderRawSocket()
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -214,57 +217,17 @@ generatePort:
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localhost,
|
||||
SrcIP: localhost,
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(port),
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -18,41 +18,42 @@ import (
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
||||
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
return &ProxyWrapper{
|
||||
WgeBPFProxy: WgeBPFProxy,
|
||||
wgeBPFProxy: proxy,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
return p.wgRelayedEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||
@@ -64,14 +65,19 @@ func (p *ProxyWrapper) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
|
||||
p.pausedCond.L.Unlock()
|
||||
// todo: review to should be inside the lock scope
|
||||
p.pausedCond.Signal()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
@@ -80,45 +86,59 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
func (p *ProxyWrapper) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
e.cancel()
|
||||
p.cancel()
|
||||
|
||||
e.closeListener.SetCloseListener(nil)
|
||||
p.closeListener.SetCloseListener(nil)
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("close remote conn: %w", err)
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
@@ -137,7 +157,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
p.closeListener.Notify()
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
}
|
||||
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// KernelFactory todo: check eBPF support on FreeBSD
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
@@ -11,6 +11,12 @@ type Proxy interface {
|
||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
/*
|
||||
RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
|
||||
and rewrite the src address to the endpoint address.
|
||||
With this logic can avoid the package loss from relayed connections.
|
||||
*/
|
||||
RedirectAs(endpoint *net.UDPAddr)
|
||||
CloseConn() error
|
||||
SetDisconnectListener(disconnected func())
|
||||
}
|
||||
|
||||
@@ -3,54 +3,82 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
t.Skip("Skipping test as it requires root privileges")
|
||||
}
|
||||
ctx := context.Background()
|
||||
func seedProxies() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "ebpf proxy",
|
||||
proxy: &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
},
|
||||
},
|
||||
pEbpf := proxyInstance{
|
||||
name: "ebpf kernel proxy",
|
||||
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||
wgPort: 51831,
|
||||
closeFn: ebpfProxy.Free,
|
||||
}
|
||||
pl = append(pl, pEbpf)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
pUDP := proxyInstance{
|
||||
name: "udp kernel proxy",
|
||||
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||
wgPort: 51832,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pUDP)
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
pEbpf := proxyInstance{
|
||||
name: "ebpf kernel proxy",
|
||||
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||
wgPort: 51831,
|
||||
closeFn: ebpfProxy.Free,
|
||||
}
|
||||
pl = append(pl, pEbpf)
|
||||
|
||||
pUDP := proxyInstance{
|
||||
name: "udp kernel proxy",
|
||||
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||
wgPort: 51832,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pUDP)
|
||||
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
|
||||
pBind := proxyInstance{
|
||||
name: "bind proxy",
|
||||
proxy: bindproxy.NewProxyBind(iceBind),
|
||||
endpointAddr: endpointAddress,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pBind)
|
||||
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build !linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
func seedProxies() ([]proxyInstance, error) {
|
||||
// todo extend with Bind proxy
|
||||
pl := make([]proxyInstance, 0)
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
|
||||
pBind := proxyInstance{
|
||||
name: "bind proxy",
|
||||
proxy: bindproxy.NewProxyBind(iceBind),
|
||||
endpointAddr: endpointAddress,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pBind)
|
||||
return pl, nil
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
@@ -7,12 +5,9 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type proxyInstance struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
wgPort int
|
||||
endpointAddr *net.UDPAddr
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
|
||||
},
|
||||
tests, err := seedProxyForProxyCloseByRemoteConn()
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||
defer func() {
|
||||
_ = relayedConn.Close()
|
||||
}()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyRedirect todo extend the proxies with Bind proxy
|
||||
func TestProxyRedirect(t *testing.T) {
|
||||
tests, err := seedProxies()
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
|
||||
if err := tt.closeFn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
msgHelloFromRelay := []byte("hello from relay")
|
||||
msgRedirected := [][]byte{
|
||||
[]byte("hello 1. to p2p"),
|
||||
[]byte("hello 2. to p2p"),
|
||||
[]byte("hello 3. to p2p"),
|
||||
}
|
||||
|
||||
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: wgPort})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on udp port: %s", err)
|
||||
}
|
||||
|
||||
relayedServer, _ := net.ListenUDP("udp",
|
||||
&net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 1234,
|
||||
},
|
||||
)
|
||||
|
||||
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||
|
||||
defer func() {
|
||||
_ = dummyWgListener.Close()
|
||||
_ = relayedConn.Close()
|
||||
_ = relayedServer.Close()
|
||||
}()
|
||||
|
||||
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
|
||||
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
|
||||
}
|
||||
|
||||
n, err := dummyWgListener.Read(make([]byte, 1024))
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgHelloFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
|
||||
}
|
||||
|
||||
p2pEndpointAddr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 0, 56),
|
||||
Port: 1234,
|
||||
}
|
||||
proxy.RedirectAs(p2pEndpointAddr)
|
||||
|
||||
for _, msg := range msgRedirected {
|
||||
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(msgRedirected); i++ {
|
||||
buf := make([]byte, 1024)
|
||||
n, rAddr, err := dummyWgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
if rAddr.String() != p2pEndpointAddr.String() {
|
||||
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
|
||||
}
|
||||
if string(buf[:n]) != string(msgRedirected[i]) {
|
||||
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package rawsocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
|
||||
localWGListenPort int
|
||||
mtu uint16
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
srcFakerConn *SrcFaker
|
||||
sendPkg func(data []byte) (int, error)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
mtu: mtu,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
return p
|
||||
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.sendPkg = p.localConn.Write
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
p.sendPkg = p.localConn.Write
|
||||
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// RedirectAs start to use the fake sourced raw socket as package sender
|
||||
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
defer func() {
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
}()
|
||||
|
||||
p.paused = false
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create src faker conn: %s", err)
|
||||
// fallback to continue without redirecting
|
||||
p.paused = true
|
||||
return
|
||||
}
|
||||
p.srcFakerConn = srcFakerConn
|
||||
p.sendPkg = p.srcFakerConn.SendPkg
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) close() error {
|
||||
var result *multierror.Error
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.L.Unlock()
|
||||
p.pausedCond.Signal()
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
return cerrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
_, err = p.sendPkg(buf[:n])
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||
)
|
||||
|
||||
var (
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *SrcFaker) Close() error {
|
||||
return f.rawSocket.Close()
|
||||
}
|
||||
|
||||
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
defer func() {
|
||||
if err := f.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return ipH, udpH, nil
|
||||
}
|
||||
@@ -280,12 +280,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
select {
|
||||
case runningChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
@@ -29,9 +29,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
@@ -166,7 +166,7 @@ type Engine struct {
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
@@ -461,7 +461,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
@@ -1326,7 +1326,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
},
|
||||
|
||||
@@ -19,18 +19,21 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -42,12 +45,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -85,7 +85,7 @@ type MockWGIface struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
@@ -135,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
||||
@@ -1555,11 +1555,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -1576,6 +1572,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
ToInterface() *net.Interface
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -173,7 +174,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -375,12 +376,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.Log.Debugf("redirect packages from relayed conn to WireGuard")
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -418,6 +426,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
@@ -30,10 +30,9 @@ type WGWatcher struct {
|
||||
peerKey string
|
||||
stateDump *stateDump
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
enabledTime time.Time
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -49,7 +48,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
w.enabledTime = time.Now()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
@@ -104,8 +102,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
||||
return
|
||||
}
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := handshake.Sub(w.enabledTime).Seconds()
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
w.log.Infof("first wg handshake detected at: %s", handshake)
|
||||
}
|
||||
|
||||
lastHandshake = *handshake
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v4"
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -54,6 +55,10 @@ type WorkerICE struct {
|
||||
sessionID ICESessionID
|
||||
muxAgent sync.Mutex
|
||||
|
||||
StunTurn []*stun.URI
|
||||
|
||||
sentExtraSrflx bool
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
@@ -134,6 +139,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.sentExtraSrflx = false
|
||||
w.agent = agent
|
||||
w.agentDialerCancel = dialerCancel
|
||||
w.agentConnecting = true
|
||||
@@ -160,21 +166,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
}
|
||||
|
||||
if shouldAddExtraCandidate(candidate) {
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
|
||||
@@ -218,9 +209,7 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
|
||||
w.onICESelectedCandidatePair(agent, c1, c2)
|
||||
}); err != nil {
|
||||
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -338,7 +327,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
|
||||
return
|
||||
}
|
||||
|
||||
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
|
||||
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
|
||||
if !ok {
|
||||
w.log.Warn("invalid udp mux conversion")
|
||||
return
|
||||
@@ -365,19 +354,48 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.config.Key)
|
||||
|
||||
pairStat, ok := agent.GetSelectedCandidatePairStats()
|
||||
if !ok {
|
||||
w.log.Warnf("failed to get selected candidate pair stats")
|
||||
if !w.shouldSendExtraSrflxCandidate(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
w.sentExtraSrflx = true
|
||||
|
||||
go func() {
|
||||
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.config.Key)
|
||||
|
||||
w.muxAgent.Lock()
|
||||
|
||||
pair, err := w.agent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to get selected candidate pair: %s", err)
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
if pair == nil {
|
||||
w.log.Warnf("selected candidate pair is nil, cannot proceed")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
|
||||
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
|
||||
w.log.Debugf("failed to update latency for peer: %s", err)
|
||||
return
|
||||
@@ -406,31 +424,22 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
|
||||
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
if isController(w.config) {
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
isControlling := w.config.LocalKey > w.config.Key
|
||||
if isControlling {
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
|
||||
if candidate.Type() != ice.CandidateTypeServerReflexive {
|
||||
return false
|
||||
}
|
||||
|
||||
if candidate.Port() == candidate.RelatedAddress().Port {
|
||||
return false
|
||||
}
|
||||
|
||||
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
|
||||
// in newer version we generate locally the extra candidate
|
||||
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
@@ -446,10 +455,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
}
|
||||
|
||||
for _, e := range candidate.Extensions() {
|
||||
// overwrite the original candidate ID with the new one to avoid candidate duplication
|
||||
if e.Key == ice.ExtensionKeyCandidateID {
|
||||
e.Value = candidate.ID()
|
||||
}
|
||||
if err := ec.AddExtension(e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
||||
if netstack.IsEnabled() {
|
||||
n.iFaceDiscover = pionDiscover{}
|
||||
} else {
|
||||
n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
|
||||
newMobileIFaceDiscover(iFaceDiscover)
|
||||
}
|
||||
return n, n.UpdateInterfaces()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc v3.21.9
|
||||
// protoc v5.29.3
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
@@ -791,12 +791,11 @@ func (*UpResponse) Descriptor() ([]byte, []int) {
|
||||
}
|
||||
|
||||
type StatusRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
|
||||
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
|
||||
WaitForConnectingShift bool `protobuf:"varint,3,opt,name=waitForConnectingShift,proto3" json:"waitForConnectingShift,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
|
||||
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StatusRequest) Reset() {
|
||||
@@ -843,13 +842,6 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *StatusRequest) GetWaitForConnectingShift() bool {
|
||||
if x != nil {
|
||||
return x.WaitForConnectingShift
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// status of the server.
|
||||
@@ -4681,11 +4673,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_username\"\f\n" +
|
||||
"\n" +
|
||||
"UpResponse\"\x9f\x01\n" +
|
||||
"UpResponse\"g\n" +
|
||||
"\rStatusRequest\x12,\n" +
|
||||
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
|
||||
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x126\n" +
|
||||
"\x16waitForConnectingShift\x18\x03 \x01(\bR\x16waitForConnectingShift\"\x82\x01\n" +
|
||||
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
|
||||
"\x0eStatusResponse\x12\x16\n" +
|
||||
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
|
||||
"\n" +
|
||||
|
||||
@@ -186,7 +186,6 @@ message UpResponse {}
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
bool waitForConnectingShift = 3;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
|
||||
@@ -65,8 +65,6 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
clientRunning bool // protected by mutex
|
||||
clientRunningChan chan struct{}
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
|
||||
@@ -105,7 +103,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
func (s *Server) Start() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
@@ -119,12 +116,14 @@ func (s *Server) Start() error {
|
||||
// if current state contains any error, return it
|
||||
// in all other cases we can continue execution only if status is idle and up command was
|
||||
// not in the progress or already successfully established connection.
|
||||
_, err := state.Status()
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state.Set(internal.StatusConnecting)
|
||||
if status != internal.StatusIdle {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
@@ -173,12 +172,8 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.clientRunning {
|
||||
return nil
|
||||
}
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{}, 1)
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -209,22 +204,12 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
|
||||
defer func() {
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
s.mutex.Unlock()
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
|
||||
log.Debugf("run client connection exited with error: %v", err)
|
||||
}
|
||||
log.Tracef("client connection exited")
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status,
|
||||
runningChan chan struct{},
|
||||
) {
|
||||
backOff := getConnectWithBackoff(ctx)
|
||||
retryStarted := false
|
||||
|
||||
go func() {
|
||||
t := time.NewTicker(24 * time.Hour)
|
||||
for {
|
||||
@@ -233,34 +218,91 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
mgmtState := statusRecorder.GetManagementState()
|
||||
signalState := statusRecorder.GetSignalState()
|
||||
if mgmtState.Connected && signalState.Connected {
|
||||
log.Tracef("resetting status")
|
||||
backOff.Reset()
|
||||
} else {
|
||||
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
|
||||
if retryStarted {
|
||||
|
||||
mgmtState := statusRecorder.GetManagementState()
|
||||
signalState := statusRecorder.GetSignalState()
|
||||
if mgmtState.Connected && signalState.Connected {
|
||||
log.Tracef("resetting status")
|
||||
retryStarted = false
|
||||
} else {
|
||||
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
runOperation := func() error {
|
||||
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
|
||||
log.Tracef("running client connection")
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
|
||||
err := s.connectClient.Run(runningChan)
|
||||
if err != nil {
|
||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("client connection exited gracefully, do not need to retry")
|
||||
return nil
|
||||
if config.DisableAutoConnect {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
if !retryStarted {
|
||||
retryStarted = true
|
||||
backOff.Reset()
|
||||
}
|
||||
|
||||
log.Tracef("client connection exited")
|
||||
return fmt.Errorf("client connection exited")
|
||||
}
|
||||
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
err := backoff.Retry(runOperation, backOff)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
|
||||
log.Errorf("received an error when trying to connect: %v", err)
|
||||
} else {
|
||||
log.Tracef("retry canceled")
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
|
||||
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
|
||||
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
|
||||
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
|
||||
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
|
||||
multiplier := defaultRetryMultiplier
|
||||
|
||||
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
|
||||
// parse the multiplier from the environment variable string value to float64
|
||||
value, err := strconv.ParseFloat(envValue, 64)
|
||||
if err != nil {
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
|
||||
} else {
|
||||
multiplier = value
|
||||
}
|
||||
}
|
||||
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: initialInterval,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: multiplier,
|
||||
MaxInterval: maxInterval,
|
||||
MaxElapsedTime: maxElapsedTime, // 14 days
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// parseEnvDuration parses the environment variable and returns the duration
|
||||
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
|
||||
if envValue := os.Getenv(envVar); envValue != "" {
|
||||
if duration, err := time.ParseDuration(envValue); err == nil {
|
||||
return duration
|
||||
}
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
|
||||
}
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
var status internal.StatusType
|
||||
@@ -674,14 +716,11 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.clientRunning {
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{}, 1)
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
|
||||
}
|
||||
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
|
||||
for {
|
||||
select {
|
||||
case <-s.clientRunningChan:
|
||||
case <-runningChan:
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
case <-callerCtx.Done():
|
||||
@@ -959,33 +998,6 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
|
||||
return mgmClient.Logout()
|
||||
}
|
||||
|
||||
func waitStateShift(ctx context.Context) {
|
||||
timer := time.NewTimer(5 * time.Second)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warnf("context done while waiting for state shift: %v", ctx.Err())
|
||||
timer.Stop()
|
||||
return
|
||||
case <-timer.C:
|
||||
log.Warnf("state shift timed out")
|
||||
timer.Stop()
|
||||
return
|
||||
default:
|
||||
status, err := internal.CtxGetState(ctx).Status()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get status: %v", err)
|
||||
return
|
||||
}
|
||||
if status != internal.StatusConnecting {
|
||||
log.Infof("state shifting status: %v", status)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the daemon status
|
||||
func (s *Server) Status(
|
||||
ctx context.Context,
|
||||
@@ -998,10 +1010,6 @@ func (s *Server) Status(
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForConnectingShift {
|
||||
waitStateShift(s.rootCtx)
|
||||
}
|
||||
|
||||
status, err := internal.CtxGetState(s.rootCtx).Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1119,134 +1127,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddProfile adds a new profile to the daemon.
|
||||
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if msg.ProfileName == "" || msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
|
||||
}
|
||||
|
||||
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
log.Errorf("failed to create profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to create profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.AddProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveProfile removes a profile from the daemon.
|
||||
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RemoveProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// ListProfiles lists all profiles in the daemon.
|
||||
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
|
||||
}
|
||||
|
||||
profiles, err := s.profileManager.ListProfiles(msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list profiles: %v", err)
|
||||
return nil, fmt.Errorf("failed to list profiles: %w", err)
|
||||
}
|
||||
|
||||
response := &proto.ListProfilesResponse{
|
||||
Profiles: make([]*proto.Profile, len(profiles)),
|
||||
}
|
||||
for i, profile := range profiles {
|
||||
response.Profiles[i] = &proto.Profile{
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile in the daemon.
|
||||
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
activeProfile, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
return &proto.GetActiveProfileResponse{
|
||||
ProfileName: activeProfile.Name,
|
||||
Username: activeProfile.Username,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetFeatures returns the features supported by the daemon.
|
||||
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
}
|
||||
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
|
||||
log.Tracef("running client connection")
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||
if err := s.connectClient.Run(runningChan); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) checkProfilesDisabled() bool {
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.profilesDisabled {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) checkUpdateSettingsDisabled() bool {
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.updateSettingsDisabled {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) onSessionExpire() {
|
||||
if runtime.GOOS != "windows" {
|
||||
isUIActive := internal.CheckUIApp()
|
||||
@@ -1258,45 +1138,6 @@ func (s *Server) onSessionExpire() {
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
|
||||
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
|
||||
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
|
||||
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
|
||||
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
|
||||
multiplier := defaultRetryMultiplier
|
||||
|
||||
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
|
||||
// parse the multiplier from the environment variable string value to float64
|
||||
value, err := strconv.ParseFloat(envValue, 64)
|
||||
if err != nil {
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
|
||||
} else {
|
||||
multiplier = value
|
||||
}
|
||||
}
|
||||
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: initialInterval,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: multiplier,
|
||||
MaxInterval: maxInterval,
|
||||
MaxElapsedTime: maxElapsedTime, // 14 days
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// parseEnvDuration parses the environment variable and returns the duration
|
||||
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
|
||||
if envValue := os.Getenv(envVar); envValue != "" {
|
||||
if duration, err := time.ParseDuration(envValue); err == nil {
|
||||
return duration
|
||||
}
|
||||
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
|
||||
}
|
||||
return defaultDuration
|
||||
}
|
||||
|
||||
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
pbFullStatus := proto.FullStatus{
|
||||
ManagementState: &proto.ManagementState{},
|
||||
@@ -1411,3 +1252,121 @@ func sendTerminalNotification() error {
|
||||
|
||||
return wallCmd.Wait()
|
||||
}
|
||||
|
||||
// AddProfile adds a new profile to the daemon.
|
||||
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if msg.ProfileName == "" || msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
|
||||
}
|
||||
|
||||
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
log.Errorf("failed to create profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to create profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.AddProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveProfile removes a profile from the daemon.
|
||||
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RemoveProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// ListProfiles lists all profiles in the daemon.
|
||||
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
|
||||
}
|
||||
|
||||
profiles, err := s.profileManager.ListProfiles(msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list profiles: %v", err)
|
||||
return nil, fmt.Errorf("failed to list profiles: %w", err)
|
||||
}
|
||||
|
||||
response := &proto.ListProfilesResponse{
|
||||
Profiles: make([]*proto.Profile, len(profiles)),
|
||||
}
|
||||
for i, profile := range profiles {
|
||||
response.Profiles[i] = &proto.Profile{
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile in the daemon.
|
||||
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
activeProfile, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
return &proto.GetActiveProfileResponse{
|
||||
ProfileName: activeProfile.Name,
|
||||
Username: activeProfile.Username,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetFeatures returns the features supported by the daemon.
|
||||
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
}
|
||||
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkProfilesDisabled() bool {
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.profilesDisabled {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) checkUpdateSettingsDisabled() bool {
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.updateSettingsDisabled {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -10,25 +10,25 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -294,20 +294,15 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersManager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
|
||||
@@ -96,6 +96,14 @@ func (i *Info) SetFlags(
|
||||
i.LazyConnectionEnabled = lazyConnectionEnabled
|
||||
}
|
||||
|
||||
// StaticInfo is an object that contains machine information that does not change
|
||||
type StaticInfo struct {
|
||||
SystemSerialNumber string
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
func extractUserAgent(ctx context.Context) string {
|
||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||
@@ -191,3 +199,10 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// UpdateStaticInfo asynchronously updates static system and platform information
|
||||
func UpdateStaticInfo() {
|
||||
go func() {
|
||||
_ = updateStaticInfo()
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -15,11 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
|
||||
func UpdateStaticInfoAsync() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
kernel := "android"
|
||||
|
||||
@@ -19,10 +19,6 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
func UpdateStaticInfoAsync() {
|
||||
go updateStaticInfo()
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
utsname := unix.Utsname{}
|
||||
@@ -45,7 +41,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
si := getStaticInfo()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
}
|
||||
|
||||
@@ -18,11 +18,6 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
|
||||
func UpdateStaticInfoAsync() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
out := _getInfo()
|
||||
|
||||
@@ -10,11 +10,6 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
|
||||
func UpdateStaticInfoAsync() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
|
||||
|
||||
@@ -23,10 +23,6 @@ var (
|
||||
getSystemInfo = defaultSysInfoImplementation
|
||||
)
|
||||
|
||||
func UpdateStaticInfoAsync() {
|
||||
go updateStaticInfo()
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
info := _getInfo()
|
||||
@@ -52,7 +48,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
si := getStaticInfo()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
}
|
||||
|
||||
@@ -2,51 +2,187 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
func UpdateStaticInfoAsync() {
|
||||
go updateStaticInfo()
|
||||
type Win32_OperatingSystem struct {
|
||||
Caption string
|
||||
}
|
||||
|
||||
type Win32_ComputerSystem struct {
|
||||
Manufacturer string
|
||||
}
|
||||
|
||||
type Win32_ComputerSystemProduct struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Win32_BIOS struct {
|
||||
SerialNumber string
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
osName, osVersion := getOSNameAndVersion()
|
||||
buildVersion := getBuildVersion()
|
||||
|
||||
addrs, err := networkAddresses()
|
||||
if err != nil {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
si := getStaticInfo()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
Kernel: "windows",
|
||||
OSVersion: si.OSVersion,
|
||||
OSVersion: osVersion,
|
||||
Platform: "unknown",
|
||||
OS: si.OSName,
|
||||
OS: osName,
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: si.BuildVersion,
|
||||
KernelVersion: buildVersion,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
}
|
||||
|
||||
addrs, err := networkAddresses()
|
||||
if err != nil {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
} else {
|
||||
gio.NetworkAddresses = addrs
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.NetbirdVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
return gio
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var err error
|
||||
serialNumber, err = sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
productName, err = sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err = sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
return serialNumber, productName, manufacturer
|
||||
}
|
||||
|
||||
func getOSNameAndVersion() (string, string) {
|
||||
var dst []Win32_OperatingSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return "Windows", getBuildVersion()
|
||||
}
|
||||
|
||||
if len(dst) == 0 {
|
||||
return "Windows", getBuildVersion()
|
||||
}
|
||||
|
||||
split := strings.Split(dst[0].Caption, " ")
|
||||
|
||||
if len(split) <= 3 {
|
||||
return "Windows", getBuildVersion()
|
||||
}
|
||||
|
||||
name := split[1]
|
||||
version := split[2]
|
||||
if split[2] == "Server" {
|
||||
name = fmt.Sprintf("%s %s", split[1], split[2])
|
||||
version = split[3]
|
||||
}
|
||||
|
||||
return name, version
|
||||
}
|
||||
|
||||
func getBuildVersion() string {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return "0.0.0.0"
|
||||
}
|
||||
defer func() {
|
||||
deferErr := k.Close()
|
||||
if deferErr != nil {
|
||||
log.Error(deferErr)
|
||||
}
|
||||
}()
|
||||
|
||||
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
build, _, err := k.GetStringValue("CurrentBuildNumber")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
// Update Build Revision
|
||||
ubr, _, err := k.GetIntegerValue("UBR")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
|
||||
return ver
|
||||
}
|
||||
|
||||
func sysNumber() (string, error) {
|
||||
var dst []Win32_BIOS
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dst[0].SerialNumber, nil
|
||||
}
|
||||
|
||||
func sysProductName() (string, error) {
|
||||
var dst []Win32_ComputerSystemProduct
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// `ComputerSystemProduct` could be empty on some virtualized systems
|
||||
if len(dst) < 1 {
|
||||
return "unknown", nil
|
||||
}
|
||||
return dst[0].Name, nil
|
||||
}
|
||||
|
||||
func sysManufacturer() (string, error) {
|
||||
var dst []Win32_ComputerSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dst[0].Manufacturer, nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,12 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -11,26 +16,25 @@ var (
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// StaticInfo is an object that contains machine information that does not change
|
||||
type StaticInfo struct {
|
||||
SystemSerialNumber string
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
|
||||
// Windows specific fields
|
||||
OSName string
|
||||
OSVersion string
|
||||
BuildVersion string
|
||||
}
|
||||
|
||||
func updateStaticInfo() {
|
||||
func updateStaticInfo() StaticInfo {
|
||||
once.Do(func() {
|
||||
staticInfo = newStaticInfo()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func getStaticInfo() StaticInfo {
|
||||
updateStaticInfo()
|
||||
return staticInfo
|
||||
}
|
||||
|
||||
8
client/system/static_info_stub.go
Normal file
8
client/system/static_info_stub.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build android || freebsd || ios
|
||||
|
||||
package system
|
||||
|
||||
// updateStaticInfo returns an empty implementation for unsupported platforms
|
||||
func updateStaticInfo() StaticInfo {
|
||||
return StaticInfo{}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
//go:build (linux && !android) || (darwin && !ios)
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
)
|
||||
|
||||
func newStaticInfo() StaticInfo {
|
||||
si := StaticInfo{}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
si.Environment.Cloud = detect_cloud.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
si.Environment.Platform = detect_platform.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
return si
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
)
|
||||
|
||||
type Win32_OperatingSystem struct {
|
||||
Caption string
|
||||
}
|
||||
|
||||
type Win32_ComputerSystem struct {
|
||||
Manufacturer string
|
||||
}
|
||||
|
||||
type Win32_ComputerSystemProduct struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Win32_BIOS struct {
|
||||
SerialNumber string
|
||||
}
|
||||
|
||||
func newStaticInfo() StaticInfo {
|
||||
si := StaticInfo{}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
si.Environment.Cloud = detect_cloud.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
si.Environment.Platform = detect_platform.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
si.OSName, si.OSVersion = getOSNameAndVersion()
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
si.BuildVersion = getBuildVersion()
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
return si
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var err error
|
||||
serialNumber, err = sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
productName, err = sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err = sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
return serialNumber, productName, manufacturer
|
||||
}
|
||||
|
||||
func sysNumber() (string, error) {
|
||||
var dst []Win32_BIOS
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dst[0].SerialNumber, nil
|
||||
}
|
||||
|
||||
func sysProductName() (string, error) {
|
||||
var dst []Win32_ComputerSystemProduct
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// `ComputerSystemProduct` could be empty on some virtualized systems
|
||||
if len(dst) < 1 {
|
||||
return "unknown", nil
|
||||
}
|
||||
return dst[0].Name, nil
|
||||
}
|
||||
|
||||
func sysManufacturer() (string, error) {
|
||||
var dst []Win32_ComputerSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dst[0].Manufacturer, nil
|
||||
}
|
||||
|
||||
func getOSNameAndVersion() (string, string) {
|
||||
var dst []Win32_OperatingSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
err := wmi.Query(query, &dst)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return "Windows", getBuildVersion()
|
||||
}
|
||||
|
||||
if len(dst) == 0 {
|
||||
return "Windows", getBuildVersion()
|
||||
}
|
||||
|
||||
split := strings.Split(dst[0].Caption, " ")
|
||||
|
||||
if len(split) <= 3 {
|
||||
return "Windows", getBuildVersion()
|
||||
}
|
||||
|
||||
name := split[1]
|
||||
version := split[2]
|
||||
if split[2] == "Server" {
|
||||
name = fmt.Sprintf("%s %s", split[1], split[2])
|
||||
version = split[3]
|
||||
}
|
||||
|
||||
return name, version
|
||||
}
|
||||
|
||||
func getBuildVersion() string {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return "0.0.0.0"
|
||||
}
|
||||
defer func() {
|
||||
deferErr := k.Close()
|
||||
if deferErr != nil {
|
||||
log.Error(deferErr)
|
||||
}
|
||||
}()
|
||||
|
||||
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
build, _, err := k.GetStringValue("CurrentBuildNumber")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
// Update Build Revision
|
||||
ubr, _, err := k.GetIntegerValue("UBR")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
|
||||
return ver
|
||||
}
|
||||
@@ -529,7 +529,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
var req proto.SetConfigRequest
|
||||
req.ProfileName = activeProf.Name
|
||||
req.Username = currUser.Username
|
||||
|
||||
|
||||
if iMngURL != "" {
|
||||
req.ManagementUrl = iMngURL
|
||||
}
|
||||
@@ -563,28 +563,27 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
// run down & up
|
||||
_, err = conn.Down(s.ctx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
|
||||
log.Errorf("down service: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
// run down & up
|
||||
_, err = conn.Down(s.ctx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
},
|
||||
OnCancel: func() {
|
||||
|
||||
4
go.mod
4
go.mod
@@ -62,7 +62,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@@ -261,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
|
||||
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107
|
||||
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
8
go.sum
8
go.sum
@@ -501,10 +501,10 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
|
||||
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -328,45 +328,6 @@ delete_auto_service_user() {
|
||||
echo "$PARSED_RESPONSE"
|
||||
}
|
||||
|
||||
delete_default_zitadel_admin() {
|
||||
INSTANCE_URL=$1
|
||||
PAT=$2
|
||||
|
||||
# Search for the default zitadel-admin user
|
||||
RESPONSE=$(
|
||||
curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \
|
||||
-H "Authorization: Bearer $PAT" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"queries": [
|
||||
{
|
||||
"userNameQuery": {
|
||||
"userName": "zitadel-admin@",
|
||||
"method": "TEXT_QUERY_METHOD_STARTS_WITH"
|
||||
}
|
||||
}
|
||||
]
|
||||
}'
|
||||
)
|
||||
|
||||
DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty')
|
||||
|
||||
if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then
|
||||
echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID"
|
||||
|
||||
RESPONSE=$(
|
||||
curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \
|
||||
-H "Authorization: Bearer $PAT" \
|
||||
-H "Content-Type: application/json" \
|
||||
)
|
||||
PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"')
|
||||
handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE"
|
||||
|
||||
else
|
||||
echo "Default zitadel-admin user not found: $RESPONSE"
|
||||
fi
|
||||
}
|
||||
|
||||
init_zitadel() {
|
||||
echo -e "\nInitializing Zitadel with NetBird's applications\n"
|
||||
INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
|
||||
@@ -385,9 +346,6 @@ init_zitadel() {
|
||||
echo -n "Waiting for Zitadel to become ready "
|
||||
wait_api "$INSTANCE_URL" "$PAT"
|
||||
|
||||
echo "Deleting default zitadel-admin user..."
|
||||
delete_default_zitadel_admin "$INSTANCE_URL" "$PAT"
|
||||
|
||||
# create the zitadel project
|
||||
echo "Creating new zitadel project"
|
||||
PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT")
|
||||
|
||||
@@ -20,11 +20,7 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
context.Background(),
|
||||
s.PeersManager(),
|
||||
s.SettingsManager(),
|
||||
s.EventStore())
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
|
||||
if err != nil {
|
||||
log.Errorf("failed to create integrated peer validator: %v", err)
|
||||
}
|
||||
|
||||
@@ -1714,9 +1714,7 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account
|
||||
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
if peer.UserID != "" {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
if len(peers) > 0 {
|
||||
err := am.expireAndUpdatePeers(ctx, accountID, peers)
|
||||
|
||||
@@ -18,7 +18,6 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
|
||||
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
|
||||
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -62,7 +61,3 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
|
||||
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
@@ -79,18 +79,3 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs)
|
||||
ret0, _ := ret[0].([]*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs.
|
||||
func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
@@ -167,22 +167,10 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
if policy.ID != "" {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Refactor to support multiple rules per policy
|
||||
existingRuleIDs := make(map[string]bool)
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
existingRuleIDs[rule.ID] = true
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
policy.ID = xid.New().String()
|
||||
policy.AccountID = accountID
|
||||
|
||||
@@ -2847,22 +2847,3 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return []*nbpeer.Peer{}, nil
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
|
||||
Select("DISTINCT peer_id").
|
||||
Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
|
||||
|
||||
result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
@@ -3607,113 +3607,3 @@ func intToIPv4(n uint32) net.IP {
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group1ID := "test-group-1"
|
||||
group2ID := "test-group-2"
|
||||
emptyGroupID := "empty-group"
|
||||
|
||||
peer1 := "cfefqs706sqkneg59g4g"
|
||||
peer2 := "cfeg6sf06sqkneg59g50"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedPeers []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve peers from single group with multiple peers",
|
||||
groupIDs: []string{group1ID},
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from single group with one peer",
|
||||
groupIDs: []string{group2ID},
|
||||
expectedPeers: []string{peer1},
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from multiple groups (with overlap)",
|
||||
groupIDs: []string{group1ID, group2ID},
|
||||
expectedPeers: []string{peer1, peer2}, // should deduplicate
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from existing 'All' group",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from empty group",
|
||||
groupIDs: []string{emptyGroupID},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from non-existing group",
|
||||
groupIDs: []string{"non-existing-group"},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mix of existing and non-existing groups",
|
||||
groupIDs: []string{group1ID, "non-existing-group"},
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
groups := []*types.Group{
|
||||
{
|
||||
ID: group1ID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
{
|
||||
ID: group2ID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID))
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID))
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID))
|
||||
|
||||
peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
|
||||
if tt.expectedCount > 0 {
|
||||
actualPeerIDs := make([]string, len(peers))
|
||||
for i, peer := range peers {
|
||||
actualPeerIDs[i] = peer.ID
|
||||
}
|
||||
assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs)
|
||||
|
||||
// Verify all returned peers belong to the correct account
|
||||
for _, peer := range peers {
|
||||
assert.Equal(t, accountID, peer.AccountID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,6 @@ type Store interface {
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
|
||||
@@ -302,11 +302,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
@@ -1655,24 +1651,3 @@ func peerSupportsPortRanges(peerVer string) bool {
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
@@ -2,17 +2,14 @@ package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -838,109 +835,3 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
}
|
||||
|
||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peer *nbpeer.Peer
|
||||
customZone nbdns.CustomZone
|
||||
peersToConnect []*nbpeer.Peer
|
||||
expectedRecords []nbdns.SimpleRecord
|
||||
}{
|
||||
{
|
||||
name: "empty peers to connect",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple peers multiple records match",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for i := 1; i <= 100; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
|
||||
})
|
||||
}
|
||||
return records
|
||||
}(),
|
||||
},
|
||||
peersToConnect: func() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
peers = append(peers, &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer%d", i),
|
||||
IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)),
|
||||
})
|
||||
}
|
||||
return peers
|
||||
}(),
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
|
||||
})
|
||||
}
|
||||
return records
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "peers with multiple DNS labels",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
|
||||
{Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{
|
||||
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
|
||||
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
|
||||
},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
|
||||
assert.Equal(t, len(tt.expectedRecords), len(result))
|
||||
assert.ElementsMatch(t, tt.expectedRecords, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -942,11 +942,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
|
||||
|
||||
if peer.UserID == "" {
|
||||
// we do not want to expire peers that are added via setup key
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -130,6 +130,36 @@ repo_gpgcheck=1
|
||||
EOF
|
||||
}
|
||||
|
||||
install_aur_package() {
|
||||
INSTALL_PKGS="git base-devel go"
|
||||
REMOVE_PKGS=""
|
||||
|
||||
# Check if dependencies are installed
|
||||
for PKG in $INSTALL_PKGS; do
|
||||
if ! pacman -Q "$PKG" > /dev/null 2>&1; then
|
||||
# Install missing package(s)
|
||||
${SUDO} pacman -S "$PKG" --noconfirm
|
||||
|
||||
# Add installed package for clean up later
|
||||
REMOVE_PKGS="$REMOVE_PKGS $PKG"
|
||||
fi
|
||||
done
|
||||
|
||||
# Build package from AUR
|
||||
cd /tmp && git clone https://aur.archlinux.org/netbird.git
|
||||
cd netbird && makepkg -sri --noconfirm
|
||||
|
||||
if ! $SKIP_UI_APP; then
|
||||
cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git
|
||||
cd netbird-ui && makepkg -sri --noconfirm
|
||||
fi
|
||||
|
||||
if [ -n "$REMOVE_PKGS" ]; then
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
fi
|
||||
}
|
||||
|
||||
prepare_tun_module() {
|
||||
# Create the necessary file structure for /dev/net/tun
|
||||
if [ ! -c /dev/net/tun ]; then
|
||||
@@ -246,9 +276,12 @@ install_netbird() {
|
||||
if ! $SKIP_UI_APP; then
|
||||
${SUDO} rpm-ostree -y install netbird-ui
|
||||
fi
|
||||
# ensure the service is started after install
|
||||
${SUDO} netbird service install || true
|
||||
${SUDO} netbird service start || true
|
||||
;;
|
||||
pacman)
|
||||
${SUDO} pacman -Syy
|
||||
install_aur_package
|
||||
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
|
||||
${SUDO} systemctl enable --now netbird@main.service
|
||||
;;
|
||||
pkg)
|
||||
# Check if the package is already installed
|
||||
@@ -425,7 +458,11 @@ if type uname >/dev/null 2>&1; then
|
||||
elif [ -x "$(command -v yum)" ]; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
echo "The installation will be performed using yum package manager"
|
||||
elif [ -x "$(command -v pacman)" ]; then
|
||||
PACKAGE_MANAGER="pacman"
|
||||
echo "The installation will be performed using pacman package manager"
|
||||
fi
|
||||
|
||||
else
|
||||
echo "Unable to determine OS type from /etc/os-release"
|
||||
exit 1
|
||||
|
||||
@@ -9,30 +9,34 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -68,31 +72,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.
|
||||
EXPECT().
|
||||
ValidateUserPermissions(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
peersManger := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
@@ -109,6 +95,19 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.
|
||||
EXPECT().
|
||||
ValidateUserPermissions(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -78,10 +78,9 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
|
||||
@@ -13,8 +13,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxConcurrentServers = 7
|
||||
defaultConnectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type connResult struct {
|
||||
@@ -24,15 +27,14 @@ type connResult struct {
|
||||
}
|
||||
|
||||
type ServerPicker struct {
|
||||
TokenStore *auth.TokenStore
|
||||
ServerURLs atomic.Value
|
||||
PeerID string
|
||||
MTU uint16
|
||||
ConnectionTimeout time.Duration
|
||||
TokenStore *auth.TokenStore
|
||||
ServerURLs atomic.Value
|
||||
PeerID string
|
||||
MTU uint16
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
totalServers := len(sp.ServerURLs.Load().([]string))
|
||||
|
||||
@@ -8,15 +8,15 @@ import (
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
timeout := 5 * time.Second
|
||||
connectionTimeout = 5 * time.Second
|
||||
|
||||
sp := ServerPicker{
|
||||
TokenStore: nil,
|
||||
PeerID: "test",
|
||||
ConnectionTimeout: timeout,
|
||||
TokenStore: nil,
|
||||
PeerID: "test",
|
||||
}
|
||||
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout+1)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
|
||||
)
|
||||
|
||||
func getAttemptThresholdFromEnv() int {
|
||||
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
|
||||
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
return int(threshold)
|
||||
}
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//nolint:tenv
|
||||
func TestGetAttemptThresholdFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
|
||||
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
|
||||
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envValue == "" {
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
} else {
|
||||
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
|
||||
}
|
||||
|
||||
result := getAttemptThresholdFromEnv()
|
||||
if result != tt.expected {
|
||||
t.Fatalf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,15 +7,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second
|
||||
var (
|
||||
heartbeatTimeout = healthCheckInterval + 10*time.Second
|
||||
)
|
||||
|
||||
type ReceiverOptions struct {
|
||||
HeartbeatTimeout time.Duration
|
||||
AttemptThreshold int
|
||||
}
|
||||
|
||||
// Receiver is a healthcheck receiver
|
||||
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
|
||||
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
|
||||
@@ -32,23 +27,6 @@ type Receiver struct {
|
||||
|
||||
// NewReceiver creates a new healthcheck receiver and start the timer in the background
|
||||
func NewReceiver(log *log.Entry) *Receiver {
|
||||
opts := ReceiverOptions{
|
||||
HeartbeatTimeout: defaultHeartbeatTimeout,
|
||||
AttemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
return NewReceiverWithOpts(log, opts)
|
||||
}
|
||||
|
||||
func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
|
||||
heartbeatTimeout := opts.HeartbeatTimeout
|
||||
if heartbeatTimeout <= 0 {
|
||||
heartbeatTimeout = defaultHeartbeatTimeout
|
||||
}
|
||||
attemptThreshold := opts.AttemptThreshold
|
||||
if attemptThreshold <= 0 {
|
||||
attemptThreshold = defaultAttemptThreshold
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
r := &Receiver{
|
||||
@@ -57,10 +35,10 @@ func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
attemptThreshold: attemptThreshold,
|
||||
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
|
||||
go r.waitForHealthcheck(heartbeatTimeout)
|
||||
go r.waitForHealthcheck()
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -77,7 +55,7 @@ func (r *Receiver) Stop() {
|
||||
r.ctxCancel()
|
||||
}
|
||||
|
||||
func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) {
|
||||
func (r *Receiver) waitForHealthcheck() {
|
||||
ticker := time.NewTicker(heartbeatTimeout)
|
||||
defer ticker.Stop()
|
||||
defer r.ctxCancel()
|
||||
|
||||
@@ -2,18 +2,31 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
// Mutex to protect global variable access in tests
|
||||
var testMutex sync.Mutex
|
||||
|
||||
opts := ReceiverOptions{
|
||||
HeartbeatTimeout: 5 * time.Second,
|
||||
}
|
||||
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
select {
|
||||
@@ -25,10 +38,18 @@ func TestNewReceiver(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewReceiverNotReceive(t *testing.T) {
|
||||
opts := ReceiverOptions{
|
||||
HeartbeatTimeout: 1 * time.Second,
|
||||
}
|
||||
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 1 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
select {
|
||||
@@ -40,10 +61,18 @@ func TestNewReceiverNotReceive(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewReceiverAck(t *testing.T) {
|
||||
opts := ReceiverOptions{
|
||||
HeartbeatTimeout: 2 * time.Second,
|
||||
}
|
||||
r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 2 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
r.Heartbeat()
|
||||
@@ -68,19 +97,30 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
healthCheckInterval := 1 * time.Second
|
||||
testMutex.Lock()
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := heartbeatTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||
testMutex.Unlock()
|
||||
|
||||
opts := ReceiverOptions{
|
||||
HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond,
|
||||
AttemptThreshold: tc.threshold,
|
||||
}
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
healthCheckInterval = originalInterval
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
|
||||
receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts)
|
||||
receiver := NewReceiver(log.WithField("test_name", tc.name))
|
||||
|
||||
testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
|
||||
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
|
||||
|
||||
if tc.resetCounterOnce {
|
||||
receiver.Heartbeat()
|
||||
t.Logf("reset counter once")
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -94,6 +134,7 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
}
|
||||
t.Fatalf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,76 +2,52 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAttemptThreshold = 1
|
||||
|
||||
defaultHealthCheckInterval = 25 * time.Second
|
||||
defaultHealthCheckTimeout = 20 * time.Second
|
||||
defaultAttemptThreshold = 1
|
||||
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
|
||||
)
|
||||
|
||||
type SenderOptions struct {
|
||||
HealthCheckInterval time.Duration
|
||||
HealthCheckTimeout time.Duration
|
||||
AttemptThreshold int
|
||||
}
|
||||
var (
|
||||
healthCheckInterval = 25 * time.Second
|
||||
healthCheckTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
// Sender is a healthcheck sender
|
||||
// It will send healthcheck signal to the receiver
|
||||
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
||||
// It will also stop if the context is canceled
|
||||
type Sender struct {
|
||||
log *log.Entry
|
||||
// HealthCheck is a channel to send health check signal to the peer
|
||||
HealthCheck chan struct{}
|
||||
// Timeout is a channel to the health check signal is not received in a certain time
|
||||
Timeout chan struct{}
|
||||
|
||||
log *log.Entry
|
||||
healthCheckInterval time.Duration
|
||||
timeout time.Duration
|
||||
|
||||
ack chan struct{}
|
||||
alive bool
|
||||
attemptThreshold int
|
||||
}
|
||||
|
||||
func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender {
|
||||
if opts.HealthCheckInterval <= 0 {
|
||||
opts.HealthCheckInterval = defaultHealthCheckInterval
|
||||
}
|
||||
if opts.HealthCheckTimeout <= 0 {
|
||||
opts.HealthCheckTimeout = defaultHealthCheckTimeout
|
||||
}
|
||||
if opts.AttemptThreshold <= 0 {
|
||||
opts.AttemptThreshold = defaultAttemptThreshold
|
||||
}
|
||||
// NewSender creates a new healthcheck sender
|
||||
func NewSender(log *log.Entry) *Sender {
|
||||
hc := &Sender{
|
||||
HealthCheck: make(chan struct{}, 1),
|
||||
Timeout: make(chan struct{}, 1),
|
||||
log: log,
|
||||
healthCheckInterval: opts.HealthCheckInterval,
|
||||
timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout,
|
||||
ack: make(chan struct{}, 1),
|
||||
attemptThreshold: opts.AttemptThreshold,
|
||||
log: log,
|
||||
HealthCheck: make(chan struct{}, 1),
|
||||
Timeout: make(chan struct{}, 1),
|
||||
ack: make(chan struct{}, 1),
|
||||
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// NewSender creates a new healthcheck sender
|
||||
func NewSender(log *log.Entry) *Sender {
|
||||
opts := SenderOptions{
|
||||
HealthCheckInterval: defaultHealthCheckInterval,
|
||||
HealthCheckTimeout: defaultHealthCheckTimeout,
|
||||
AttemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
return NewSenderWithOpts(log, opts)
|
||||
}
|
||||
|
||||
// OnHCResponse sends an acknowledgment signal to the sender
|
||||
func (hc *Sender) OnHCResponse() {
|
||||
select {
|
||||
@@ -81,10 +57,10 @@ func (hc *Sender) OnHCResponse() {
|
||||
}
|
||||
|
||||
func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
ticker := time.NewTicker(hc.healthCheckInterval)
|
||||
ticker := time.NewTicker(healthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTicker := time.NewTicker(hc.timeout)
|
||||
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
|
||||
defer timeoutTicker.Stop()
|
||||
|
||||
defer close(hc.HealthCheck)
|
||||
@@ -116,3 +92,19 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *Sender) getTimeoutTime() time.Duration {
|
||||
return healthCheckInterval + healthCheckTimeout
|
||||
}
|
||||
|
||||
func getAttemptThresholdFromEnv() int {
|
||||
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
|
||||
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
return int(threshold)
|
||||
}
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
|
||||
@@ -2,23 +2,26 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
testOpts = SenderOptions{
|
||||
HealthCheckInterval: 2 * time.Second,
|
||||
HealthCheckTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
)
|
||||
func TestMain(m *testing.M) {
|
||||
// override the health check interval to speed up the test
|
||||
healthCheckInterval = 2 * time.Second
|
||||
healthCheckTimeout = 100 * time.Millisecond
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestNewHealthPeriod(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
@@ -29,7 +32,7 @@ func TestNewHealthPeriod(t *testing.T) {
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
@@ -38,19 +41,19 @@ func TestNewHealthPeriod(t *testing.T) {
|
||||
func TestNewHealthFailed(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
select {
|
||||
case <-hc.Timeout:
|
||||
case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond):
|
||||
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
|
||||
t.Fatalf("health check is not timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHealthcheckStop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) {
|
||||
func TestTimeoutReset(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
@@ -86,7 +89,7 @@ func TestTimeoutReset(t *testing.T) {
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
@@ -115,16 +118,19 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
opts := SenderOptions{
|
||||
HealthCheckInterval: 1 * time.Second,
|
||||
HealthCheckTimeout: 500 * time.Millisecond,
|
||||
AttemptThreshold: tc.threshold,
|
||||
}
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := healthCheckTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
healthCheckTimeout = 500 * time.Millisecond
|
||||
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts)
|
||||
sender := NewSender(log.WithField("test_name", tc.name))
|
||||
senderExit := make(chan struct{})
|
||||
go func() {
|
||||
sender.StartHealthCheck(ctx)
|
||||
@@ -149,7 +155,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval
|
||||
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
|
||||
|
||||
select {
|
||||
case <-sender.Timeout:
|
||||
@@ -169,7 +175,39 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("sender did not exit in time")
|
||||
}
|
||||
healthCheckInterval = originalInterval
|
||||
healthCheckTimeout = originalTimeout
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//nolint:tenv
|
||||
func TestGetAttemptThresholdFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
|
||||
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
|
||||
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envValue == "" {
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
} else {
|
||||
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
|
||||
}
|
||||
|
||||
result := getAttemptThresholdFromEnv()
|
||||
if result != tt.expected {
|
||||
t.Fatalf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user