mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
Compare commits
29 Commits
v0.56.0
...
snyk-fix-7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22df3afa78 | ||
|
|
17bab881f7 | ||
|
|
25ed58328a | ||
|
|
644ed4b934 | ||
|
|
58faa341d2 | ||
|
|
5853b5553c | ||
|
|
998fb30e1e | ||
|
|
e254b4cde5 | ||
|
|
ead1c618ba | ||
|
|
55126f990c | ||
|
|
90577682e4 | ||
|
|
dc30dcacce | ||
|
|
2c87fa6236 | ||
|
|
ec8d83ade4 | ||
|
|
3130cce72d | ||
|
|
bd23ab925e | ||
|
|
0c6f671a7c | ||
|
|
cf7f6c355f | ||
|
|
47e64d72db | ||
|
|
9e81e782e5 | ||
|
|
7aef0f67df | ||
|
|
dba7ef667d | ||
|
|
69d87343d2 | ||
|
|
5113c70943 | ||
|
|
ad8fcda67b | ||
|
|
d33f88df82 | ||
|
|
786ca6fc79 | ||
|
|
dfebdf1444 | ||
|
|
a8dcff69c2 |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -217,7 +217,7 @@ jobs:
|
||||
- arch: "386"
|
||||
raceFlag: ""
|
||||
- arch: "amd64"
|
||||
raceFlag: ""
|
||||
raceFlag: "-race"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.22"
|
||||
SIGN_PIPE_VER: "v0.0.23"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
@@ -52,7 +53,7 @@
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
|
||||
@@ -18,7 +18,7 @@ ENV \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
func exportEnvList(list *EnvList) {
|
||||
if list == nil {
|
||||
return
|
||||
}
|
||||
for k, v := range list.AllItems() {
|
||||
if err := os.Setenv(k, v); err != nil {
|
||||
log.Errorf("could not set env variable %s: %v", k, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
32
client/android/env_list.go
Normal file
32
client/android/env_list.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package android
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||
|
||||
var (
|
||||
// EnvKeyNBForceRelay Exported for Android java client
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
)
|
||||
|
||||
// EnvList wraps a Go map for export to Java
|
||||
type EnvList struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
// NewEnvList creates a new EnvList
|
||||
func NewEnvList() *EnvList {
|
||||
return &EnvList{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
// Put adds a key-value pair
|
||||
func (el *EnvList) Put(key, value string) {
|
||||
el.data[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a value by key
|
||||
func (el *EnvList) Get(key string) string {
|
||||
return el.data[key]
|
||||
}
|
||||
|
||||
func (el *EnvList) AllItems() map[string]string {
|
||||
return el.data
|
||||
}
|
||||
@@ -33,6 +33,7 @@ type ErrListener interface {
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
// Auth can register or login new client
|
||||
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
|
||||
@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
}
|
||||
|
||||
// update host's static platform and system information
|
||||
system.UpdateStaticInfo()
|
||||
system.UpdateStaticInfoAsync()
|
||||
|
||||
configFilePath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
|
||||
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
return grpc.DialContext(
|
||||
|
||||
@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
log.Info("starting NetBird service") //nolint
|
||||
|
||||
// Collect static system and platform information
|
||||
system.UpdateStaticInfo()
|
||||
system.UpdateStaticInfoAsync()
|
||||
|
||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||
p.serv = grpc.NewServer()
|
||||
|
||||
@@ -9,29 +9,26 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
sig "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
@@ -90,15 +87,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
settingsMockManager.EXPECT().
|
||||
|
||||
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{
|
||||
WaitForReady: func() *bool { b := true; return &b }(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{}, 1)
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -20,8 +20,9 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
@@ -57,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
@@ -71,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
@@ -3,7 +3,7 @@ package bind
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||
|
||||
@@ -15,8 +15,9 @@ import (
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -44,7 +45,7 @@ type ICEBind struct {
|
||||
RecvChan chan RecvMessage
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn FilterFn
|
||||
filterFn udpmux.FilterFn
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||
@@ -54,13 +55,13 @@ type ICEBind struct {
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
activityRecorder *ActivityRecorder
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
@@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
if s.udpMux == nil {
|
||||
@@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||
udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(conn),
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package bind
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||
}
|
||||
|
||||
// If sec is 0 (Unix epoch), return zero time instead
|
||||
// This indicates no handshake has occurred
|
||||
if sec == 0 {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
return time.Unix(sec, 0), nil
|
||||
}
|
||||
|
||||
@@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -7,14 +7,14 @@ import (
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create() (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
MTU() uint16
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -29,7 +30,7 @@ type WGTunDevice struct {
|
||||
name string
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
}
|
||||
return t.configurer, nil
|
||||
}
|
||||
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -26,7 +27,7 @@ type TunDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ type TunDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
|
||||
|
||||
link *wgLink
|
||||
udpMuxConn net.PacketConn
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
|
||||
filterFn bind.FilterFn
|
||||
filterFn udpmux.FilterFn
|
||||
}
|
||||
|
||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
|
||||
return configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
if t.udpMux != nil {
|
||||
return t.udpMux, nil
|
||||
}
|
||||
@@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var udpConn net.PacketConn = rawSock
|
||||
if !nbnet.AdvancedRouting() {
|
||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
||||
}
|
||||
|
||||
bindParams := bind.UniversalUDPMuxParams{
|
||||
UDPConn: udpConn,
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
MTU: t.mtu,
|
||||
}
|
||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
|
||||
go mux.ReadFromConn(t.ctx)
|
||||
t.udpMuxConn = rawSock
|
||||
t.udpMux = mux
|
||||
|
||||
@@ -10,8 +10,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
@@ -26,7 +27,7 @@ type TunNetstackDevice struct {
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
nsTun *nbnetstack.NetStackTun
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
|
||||
net *netstack.Net
|
||||
@@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
if t.device == nil {
|
||||
return nil, fmt.Errorf("device is not ready yet")
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -25,7 +26,7 @@ type USPDevice struct {
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
if t.device == nil {
|
||||
return nil, fmt.Errorf("device is not ready yet")
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -29,7 +30,7 @@ type TunDevice struct {
|
||||
device *device.Device
|
||||
nativeTunDevice *tun.NativeTun
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
err := t.device.Up()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type WGTunDevice interface {
|
||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(address wgaddr.Address) error
|
||||
WgAddress() wgaddr.Address
|
||||
MTU() uint16
|
||||
|
||||
@@ -16,9 +16,9 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
|
||||
MTU uint16
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn bind.FilterFn
|
||||
FilterFn udpmux.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
|
||||
|
||||
// Up configures a Wireguard interface
|
||||
// The interface must exist before calling this method (e.g. call interface.Create() before)
|
||||
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package bind
|
||||
package udpmux
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
||||
@@ -16,11 +16,12 @@ import (
|
||||
)
|
||||
|
||||
type udpMuxedConnParams struct {
|
||||
Mux *UDPMuxDefault
|
||||
AddrPool *sync.Pool
|
||||
Key string
|
||||
LocalAddr net.Addr
|
||||
Logger logging.LeveledLogger
|
||||
Mux *SingleSocketUDPMux
|
||||
AddrPool *sync.Pool
|
||||
Key string
|
||||
LocalAddr net.Addr
|
||||
Logger logging.LeveledLogger
|
||||
CandidateID string
|
||||
}
|
||||
|
||||
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||||
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) GetCandidateID() string {
|
||||
return c.params.CandidateID
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) isClosed() bool {
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
64
client/iface/udpmux/doc.go
Normal file
64
client/iface/udpmux/doc.go
Normal file
@@ -0,0 +1,64 @@
|
||||
// Package udpmux provides a custom implementation of a UDP multiplexer
|
||||
// that allows multiple logical ICE connections to share a single underlying
|
||||
// UDP socket. This is based on Pion's ICE library, with modifications for
|
||||
// NetBird's requirements.
|
||||
//
|
||||
// # Background
|
||||
//
|
||||
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
|
||||
// Establishment) is responsible for discovering candidate network paths
|
||||
// and maintaining connectivity between peers. Each ICE connection
|
||||
// normally requires a dedicated UDP socket. However, using one socket
|
||||
// per candidate can be inefficient and difficult to manage.
|
||||
//
|
||||
// This package introduces SingleSocketUDPMux, which allows multiple ICE
|
||||
// candidate connections (muxed connections) to share a single UDP socket.
|
||||
// It handles demultiplexing of packets based on ICE ufrag values, STUN
|
||||
// attributes, and candidate IDs.
|
||||
//
|
||||
// # Usage
|
||||
//
|
||||
// The typical flow is:
|
||||
//
|
||||
// 1. Create a UDP socket (net.PacketConn).
|
||||
// 2. Construct Params with the socket and optional logger/net stack.
|
||||
// 3. Call NewSingleSocketUDPMux(params).
|
||||
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
|
||||
// to obtain a logical PacketConn.
|
||||
// 5. Use the returned PacketConn just like a normal UDP connection.
|
||||
//
|
||||
// # STUN Message Routing Logic
|
||||
//
|
||||
// When a STUN packet arrives, the mux decides which connection should
|
||||
// receive it using this routing logic:
|
||||
//
|
||||
// Primary Routing: Candidate Pair ID
|
||||
// - Extract the candidate pair ID from the STUN message using
|
||||
// ice.CandidatePairIDFromSTUN(msg)
|
||||
// - The target candidate is the locally generated candidate that
|
||||
// corresponds to the connection that should handle this STUN message
|
||||
// - If found, use the target candidate ID to lookup the specific
|
||||
// connection in candidateConnMap
|
||||
// - Route the message directly to that connection
|
||||
//
|
||||
// Fallback Routing: Broadcasting
|
||||
// When candidate pair ID is not available or lookup fails:
|
||||
// - Collect connections from addressMap based on source address
|
||||
// - Find connection using username attribute (ufrag) from STUN message
|
||||
// - Remove duplicate connections from the list
|
||||
// - Send the STUN message to all collected connections
|
||||
//
|
||||
// # Peer Reflexive Candidate Discovery
|
||||
//
|
||||
// When a remote peer sends a STUN message from an unknown source address
|
||||
// (from a candidate that has not been exchanged via signal), the ICE
|
||||
// library will:
|
||||
// - Generate a new peer reflexive candidate for this source address
|
||||
// - Extract or assign a candidate ID based on the STUN message attributes
|
||||
// - Create a mapping between the new peer reflexive candidate ID and
|
||||
// the appropriate local connection
|
||||
//
|
||||
// This discovery mechanism ensures that STUN messages from newly discovered
|
||||
// peer reflexive candidates can be properly routed to the correct local
|
||||
// connection without requiring fallback broadcasting.
|
||||
package udpmux
|
||||
@@ -1,4 +1,4 @@
|
||||
package bind
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -22,9 +22,9 @@ import (
|
||||
|
||||
const receiveMTU = 8192
|
||||
|
||||
// UDPMuxDefault is an implementation of the interface
|
||||
type UDPMuxDefault struct {
|
||||
params UDPMuxParams
|
||||
// SingleSocketUDPMux is an implementation of the interface
|
||||
type SingleSocketUDPMux struct {
|
||||
params Params
|
||||
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
|
||||
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
||||
|
||||
// candidateConnMap maps local candidate IDs to their corresponding connection.
|
||||
candidateConnMap map[string]*udpMuxedConn
|
||||
|
||||
addressMapMu sync.RWMutex
|
||||
addressMap map[string][]*udpMuxedConn
|
||||
|
||||
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
|
||||
|
||||
const maxAddrSize = 512
|
||||
|
||||
// UDPMuxParams are parameters for UDPMux.
|
||||
type UDPMuxParams struct {
|
||||
// Params are parameters for UDPMux.
|
||||
type Params struct {
|
||||
Logger logging.LeveledLogger
|
||||
UDPConn net.PacketConn
|
||||
|
||||
@@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
// NewSingleSocketUDPMux creates an implementation of UDPMux
|
||||
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
|
||||
if params.Logger == nil {
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
|
||||
mux := &UDPMuxDefault{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
connsIPv6: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
mux := &SingleSocketUDPMux{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
connsIPv6: make(map[string]*udpMuxedConn),
|
||||
candidateConnMap: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
pool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
// big enough buffer to fit both packet and address
|
||||
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
return mux
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||
func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
// LocalAddr returns the listening address of this SingleSocketUDPMux
|
||||
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
|
||||
return m.params.UDPConn.LocalAddr()
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
|
||||
m.updateLocalAddresses()
|
||||
|
||||
m.mu.Lock()
|
||||
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network address
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
|
||||
// don't check addr for mux using unspecified address
|
||||
m.mu.Lock()
|
||||
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
c := m.createMuxedConn(ufrag)
|
||||
c := m.createMuxedConn(ufrag, candidateID)
|
||||
go func() {
|
||||
<-c.CloseChannel()
|
||||
m.RemoveConnByUfrag(ufrag)
|
||||
}()
|
||||
|
||||
m.candidateConnMap[candidateID] = c
|
||||
|
||||
if isIPv6 {
|
||||
m.connsIPv6[ufrag] = c
|
||||
} else {
|
||||
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
|
||||
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||
|
||||
// Keep lock section small to avoid deadlock with conn lock
|
||||
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||
delete(m.connsIPv4, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
delete(m.candidateConnMap, c.GetCandidateID())
|
||||
}
|
||||
if c, ok := m.connsIPv6[ufrag]; ok {
|
||||
delete(m.connsIPv6, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
delete(m.candidateConnMap, c.GetCandidateID())
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
}
|
||||
|
||||
// IsClosed returns true if the mux had been closed
|
||||
func (m *UDPMuxDefault) IsClosed() bool {
|
||||
func (m *SingleSocketUDPMux) IsClosed() bool {
|
||||
select {
|
||||
case <-m.closedChan:
|
||||
return true
|
||||
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
|
||||
}
|
||||
|
||||
// Close the mux, no further connections could be created
|
||||
func (m *UDPMuxDefault) Close() error {
|
||||
func (m *SingleSocketUDPMux) Close() error {
|
||||
var err error
|
||||
m.closeOnce.Do(func() {
|
||||
m.mu.Lock()
|
||||
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
return m.params.UDPConn.WriteTo(buf, rAddr)
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||
if m.IsClosed() {
|
||||
return
|
||||
}
|
||||
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
||||
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
|
||||
c := newUDPMuxedConn(&udpMuxedConnParams{
|
||||
Mux: m,
|
||||
Key: key,
|
||||
AddrPool: m.pool,
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
Mux: m,
|
||||
Key: key,
|
||||
AddrPool: m.pool,
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
CandidateID: candidateID,
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
|
||||
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||
|
||||
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||
remoteAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
|
||||
}
|
||||
|
||||
// If we have already seen this address dispatch to the appropriate destination
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
m.addressMapMu.RLock()
|
||||
// Try to route to specific candidate connection first
|
||||
if conn := m.findCandidateConnection(msg); conn != nil {
|
||||
return conn.writePacket(msg.Raw, remoteAddr)
|
||||
}
|
||||
|
||||
// Fallback: route to all possible connections
|
||||
return m.forwardToAllConnections(msg, addr, remoteAddr)
|
||||
}
|
||||
|
||||
// findCandidateConnection attempts to find the specific connection for a STUN message
|
||||
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
|
||||
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
|
||||
if err != nil {
|
||||
return nil
|
||||
} else if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// forwardToAllConnections forwards STUN message to all relevant connections
|
||||
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
|
||||
var destinationConnList []*udpMuxedConn
|
||||
|
||||
// Add connections from address map
|
||||
m.addressMapMu.RLock()
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
destinationConnList = append(destinationConnList, storedConns...)
|
||||
}
|
||||
m.addressMapMu.RUnlock()
|
||||
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
isIPv6 = true
|
||||
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
|
||||
// If we have already seen this address dispatch to the appropriate destination
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
if !m.connectionExists(conn, destinationConnList) {
|
||||
destinationConnList = append(destinationConnList, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
|
||||
// However, we can take a username attribute from the STUN message which contains ufrag.
|
||||
// We can use ufrag to identify the destination conn to route packet to.
|
||||
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
||||
if stunAttrErr == nil {
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
|
||||
m.mu.Lock()
|
||||
destinationConn := m.connsIPv4[ufrag]
|
||||
if isIPv6 {
|
||||
destinationConn = m.connsIPv6[ufrag]
|
||||
}
|
||||
|
||||
if destinationConn != nil {
|
||||
exists := false
|
||||
for _, conn := range destinationConnList {
|
||||
if conn.params.Key == destinationConn.params.Key {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
destinationConnList = append(destinationConnList, destinationConn)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
|
||||
// It will be discarded by the further ICE candidate logic if so.
|
||||
// Forward to all found connections
|
||||
for _, conn := range destinationConnList {
|
||||
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
|
||||
log.Errorf("could not write packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||
// findConnectionByUsername finds connection using username attribute from STUN message
|
||||
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
|
||||
attr, err := msg.Get(stun.AttrUsername)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
isIPv6 := isIPv6Address(addr)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.getConn(ufrag, isIPv6)
|
||||
}
|
||||
|
||||
// connectionExists checks if a connection already exists in the list
|
||||
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
|
||||
for _, conn := range conns {
|
||||
if conn.params.Key == target.params.Key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||
if isIPv6 {
|
||||
val, ok = m.connsIPv6[ufrag]
|
||||
} else {
|
||||
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
|
||||
return
|
||||
}
|
||||
|
||||
func isIPv6Address(addr net.Addr) bool {
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||
return udpAddr.IP.To4() == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type bufferHolder struct {
|
||||
buf []byte
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
//go:build !ios
|
||||
|
||||
package bind
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
|
||||
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
|
||||
conn.RemoveAddress(addr)
|
||||
7
client/iface/udpmux/mux_ios.go
Normal file
7
client/iface/udpmux/mux_ios.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build ios
|
||||
|
||||
package udpmux
|
||||
|
||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package bind
|
||||
package udpmux
|
||||
|
||||
/*
|
||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
|
||||
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||
type UniversalUDPMuxDefault struct {
|
||||
*UDPMuxDefault
|
||||
*SingleSocketUDPMux
|
||||
params UniversalUDPMuxParams
|
||||
|
||||
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
||||
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
udpMuxParams := UDPMuxParams{
|
||||
udpMuxParams := Params{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
Net: m.params.Net,
|
||||
}
|
||||
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
||||
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
|
||||
|
||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||
// and return a unique connection per server.
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
|
||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
|
||||
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
|
||||
}
|
||||
|
||||
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
|
||||
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
|
||||
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
|
||||
}
|
||||
|
||||
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -280,15 +280,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
select {
|
||||
case runningChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
@@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
if r.gpo {
|
||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
}
|
||||
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
@@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||
}
|
||||
@@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type ServiceViaMemory struct {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
|
||||
@@ -29,9 +29,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
@@ -166,7 +166,7 @@ type Engine struct {
|
||||
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
@@ -198,6 +198,10 @@ type Engine struct {
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -341,6 +345,9 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -446,6 +453,8 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// if inbound conns are blocked there is no need to create the ACL manager
|
||||
if e.firewall != nil && !e.config.BlockInbound {
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
@@ -461,7 +470,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
@@ -477,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -949,7 +974,6 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.LazyConnectionEnabled,
|
||||
)
|
||||
|
||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
@@ -960,7 +984,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
}
|
||||
log.Debugf("stopped receiving updates from Management Service")
|
||||
}()
|
||||
log.Debugf("connecting to Management Service updates stream")
|
||||
log.Infof("connecting to Management Service updates stream")
|
||||
}
|
||||
|
||||
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
|
||||
@@ -1327,7 +1351,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
},
|
||||
|
||||
@@ -19,21 +19,18 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -45,9 +42,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -85,7 +85,7 @@ type MockWGIface struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
@@ -135,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
|
||||
engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
|
||||
engine.ctx = ctx
|
||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
||||
@@ -1555,7 +1555,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -1572,7 +1576,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||
if !isForceRelayed() {
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
|
||||
14
client/internal/peer/env.go
Normal file
14
client/internal/peer/env.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
@@ -43,13 +43,6 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (oa *OfferAnswer) SessionIDString() string {
|
||||
if oa.SessionID == nil {
|
||||
return "unknown"
|
||||
}
|
||||
return oa.SessionID.String()
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -57,7 +50,7 @@ type Handshaker struct {
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
onNewOfferListeners []func(*OfferAnswer)
|
||||
onNewOfferListeners []*OfferListener
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
@@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
||||
l := NewOfferListener(offer)
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen(ctx context.Context) {
|
||||
@@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener(&remoteOfferAnswer)
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener(&remoteOfferAnswer)
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
h.log.Infof("stop listening for remote offers and answers")
|
||||
|
||||
62
client/internal/peer/handshaker_listener.go
Normal file
62
client/internal/peer/handshaker_listener.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type callbackFunc func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
func (oa *OfferAnswer) SessionIDString() string {
|
||||
if oa.SessionID == nil {
|
||||
return "unknown"
|
||||
}
|
||||
return oa.SessionID.String()
|
||||
}
|
||||
|
||||
type OfferListener struct {
|
||||
fn callbackFunc
|
||||
running bool
|
||||
latest *OfferAnswer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewOfferListener(fn callbackFunc) *OfferListener {
|
||||
return &OfferListener{
|
||||
fn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
// Store the latest offer
|
||||
o.latest = remoteOfferAnswer
|
||||
|
||||
// If already running, the running goroutine will pick up this latest value
|
||||
if o.running {
|
||||
return
|
||||
}
|
||||
|
||||
// Start processing
|
||||
o.running = true
|
||||
|
||||
// Process in a goroutine to avoid blocking the caller
|
||||
go func(remoteOfferAnswer *OfferAnswer) {
|
||||
for {
|
||||
o.fn(remoteOfferAnswer)
|
||||
|
||||
o.mu.Lock()
|
||||
if o.latest == nil {
|
||||
// No more work to do
|
||||
o.running = false
|
||||
o.mu.Unlock()
|
||||
return
|
||||
}
|
||||
remoteOfferAnswer = o.latest
|
||||
// Clear the latest to mark it as being processed
|
||||
o.latest = nil
|
||||
o.mu.Unlock()
|
||||
}
|
||||
}(remoteOfferAnswer)
|
||||
}
|
||||
39
client/internal/peer/handshaker_listener_test.go
Normal file
39
client/internal/peer/handshaker_listener_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_newOfferListener(t *testing.T) {
|
||||
dummyOfferAnswer := &OfferAnswer{}
|
||||
runChan := make(chan struct{}, 10)
|
||||
|
||||
longRunningFn := func(remoteOfferAnswer *OfferAnswer) {
|
||||
time.Sleep(1 * time.Second)
|
||||
runChan <- struct{}{}
|
||||
}
|
||||
|
||||
hl := NewOfferListener(longRunningFn)
|
||||
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
|
||||
// Wait for exactly 2 callbacks
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-runChan:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Timeout waiting for callback")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no additional callbacks happen
|
||||
select {
|
||||
case <-runChan:
|
||||
t.Fatal("Unexpected additional callback")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Log("Correctly received exactly 2 callbacks")
|
||||
}
|
||||
}
|
||||
@@ -30,9 +30,10 @@ type WGWatcher struct {
|
||||
peerKey string
|
||||
stateDump *stateDump
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
enabledTime time.Time
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
w.enabledTime = time.Now()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
@@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
||||
onDisconnectedFn()
|
||||
return
|
||||
}
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := handshake.Sub(w.enabledTime).Seconds()
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
}
|
||||
|
||||
lastHandshake = *handshake
|
||||
|
||||
resetTime := time.Until(handshake.Add(checkPeriod))
|
||||
|
||||
@@ -9,11 +9,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v4"
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -55,10 +54,6 @@ type WorkerICE struct {
|
||||
sessionID ICESessionID
|
||||
muxAgent sync.Mutex
|
||||
|
||||
StunTurn []*stun.URI
|
||||
|
||||
sentExtraSrflx bool
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
@@ -122,7 +117,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
w.agent = nil
|
||||
// todo consider to switch to Relay connection while establishing a new ICE connection
|
||||
}
|
||||
|
||||
var preferredCandidateTypes []ice.CandidateType
|
||||
@@ -140,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.sentExtraSrflx = false
|
||||
w.agent = agent
|
||||
w.agentDialerCancel = dialerCancel
|
||||
w.agentConnecting = true
|
||||
@@ -167,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
}
|
||||
|
||||
if shouldAddExtraCandidate(candidate) {
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
|
||||
w.log.Errorf("error while handling remote candidate")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
|
||||
@@ -210,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
|
||||
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
|
||||
w.onICESelectedCandidatePair(agent, c1, c2)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -328,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
|
||||
return
|
||||
}
|
||||
|
||||
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
|
||||
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
|
||||
if !ok {
|
||||
w.log.Warn("invalid udp mux conversion")
|
||||
return
|
||||
@@ -355,48 +365,19 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !w.shouldSendExtraSrflxCandidate(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
w.sentExtraSrflx = true
|
||||
|
||||
go func() {
|
||||
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||
w.config.Key)
|
||||
|
||||
w.muxAgent.Lock()
|
||||
|
||||
pair, err := w.agent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to get selected candidate pair: %s", err)
|
||||
w.muxAgent.Unlock()
|
||||
pairStat, ok := agent.GetSelectedCandidatePairStats()
|
||||
if !ok {
|
||||
w.log.Warnf("failed to get selected candidate pair stats")
|
||||
return
|
||||
}
|
||||
if pair == nil {
|
||||
w.log.Warnf("selected candidate pair is nil, cannot proceed")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
|
||||
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
|
||||
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
|
||||
w.log.Debugf("failed to update latency for peer: %s", err)
|
||||
return
|
||||
@@ -410,7 +391,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
@@ -422,22 +406,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
|
||||
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
isControlling := w.config.LocalKey > w.config.Key
|
||||
if isControlling {
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
if isController(w.config) {
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
|
||||
if candidate.Type() != ice.CandidateTypeServerReflexive {
|
||||
return false
|
||||
}
|
||||
|
||||
if candidate.Port() == candidate.RelatedAddress().Port {
|
||||
return false
|
||||
}
|
||||
|
||||
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
|
||||
// in newer version we generate locally the extra candidate
|
||||
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
@@ -453,6 +446,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
||||
}
|
||||
|
||||
for _, e := range candidate.Extensions() {
|
||||
// overwrite the original candidate ID with the new one to avoid candidate duplication
|
||||
if e.Key == ice.ExtensionKeyCandidateID {
|
||||
e.Value = candidate.ID()
|
||||
}
|
||||
if err := ec.AddExtension(e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
|
||||
@@ -36,9 +36,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||
|
||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||
}
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
nbnet.SetVPNInterfaceName("")
|
||||
}
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -22,7 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
const localSubnetsCacheTTL = 15 * time.Minute
|
||||
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
hooks.RemoveWriteHooks()
|
||||
hooks.RemoveCloseHooks()
|
||||
hooks.RemoveAddressRemoveHooks()
|
||||
|
||||
if err := r.refCounter.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
afterHook := func(connID hooks.ConnectionID) error {
|
||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
|
||||
continue
|
||||
}
|
||||
if err := beforeHook("init", prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
hooks.AddWriteHook(beforeHook)
|
||||
hooks.AddCloseHook(afterHook)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
})
|
||||
|
||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// IPRule contains IP rule information for debugging
|
||||
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
|
||||
if !advancedRouting {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if !advancedRouting {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -19,9 +20,16 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const InfiniteLifetime = 0xffffffff
|
||||
func init() {
|
||||
nbnet.GetBestInterfaceFunc = GetBestInterface
|
||||
}
|
||||
|
||||
const (
|
||||
InfiniteLifetime = 0xffffffff
|
||||
)
|
||||
|
||||
type RouteUpdateType int
|
||||
|
||||
@@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct {
|
||||
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
|
||||
}
|
||||
|
||||
// candidateRoute represents a potential route for selection during route lookup
|
||||
type candidateRoute struct {
|
||||
interfaceIndex uint32
|
||||
prefixLength uint8
|
||||
routeMetric uint32
|
||||
interfaceMetric int
|
||||
}
|
||||
|
||||
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
|
||||
type IP_ADDRESS_PREFIX struct {
|
||||
Prefix SOCKADDR_INET
|
||||
@@ -177,11 +193,20 @@ const (
|
||||
RouteDeleted
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Using legacy routing setup with ref counters")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
@@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
|
||||
|
||||
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
|
||||
if table != nil {
|
||||
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
||||
if ret != 0 {
|
||||
log.Warnf("FreeMibTable failed with return code: %d", ret)
|
||||
}
|
||||
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
|
||||
entryPtr := basePtr + uintptr(i)*entrySize
|
||||
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||
|
||||
detailed := buildWindowsDetailedRoute(entry)
|
||||
if detailed != nil {
|
||||
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
|
||||
detailedRoutes = append(detailedRoutes, *detailed)
|
||||
}
|
||||
}
|
||||
@@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
|
||||
return ip
|
||||
}
|
||||
|
||||
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
|
||||
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
|
||||
var candidates []candidateRoute
|
||||
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
|
||||
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
|
||||
|
||||
for i := uint32(0); i < table.NumEntries; i++ {
|
||||
entryPtr := basePtr + uintptr(i)*entrySize
|
||||
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||
|
||||
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
|
||||
candidates = append(candidates, *candidate)
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
|
||||
// Returns nil if the route doesn't match the destination or should be skipped
|
||||
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
|
||||
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
|
||||
return nil
|
||||
}
|
||||
|
||||
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
|
||||
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
|
||||
return nil
|
||||
}
|
||||
|
||||
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
|
||||
|
||||
return &candidateRoute{
|
||||
interfaceIndex: entry.InterfaceIndex,
|
||||
prefixLength: entry.DestinationPrefix.PrefixLength,
|
||||
routeMetric: entry.Metric,
|
||||
interfaceMetric: interfaceMetric,
|
||||
}
|
||||
}
|
||||
|
||||
// getInterfaceMetric retrieves the interface metric for a given interface and address family
|
||||
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
||||
if interfaceIndex == 0 {
|
||||
@@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
||||
return int(ipInterfaceRow.Metric)
|
||||
}
|
||||
|
||||
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
|
||||
func sortRouteCandidates(candidates []candidateRoute) {
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
if candidates[i].prefixLength != candidates[j].prefixLength {
|
||||
return candidates[i].prefixLength > candidates[j].prefixLength
|
||||
}
|
||||
if candidates[i].routeMetric != candidates[j].routeMetric {
|
||||
return candidates[i].routeMetric < candidates[j].routeMetric
|
||||
}
|
||||
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
|
||||
})
|
||||
}
|
||||
|
||||
// GetBestInterface finds the best interface for reaching a destination,
|
||||
// excluding the VPN interface to avoid routing loops.
|
||||
//
|
||||
// Route selection priority:
|
||||
// 1. Longest prefix match (most specific route)
|
||||
// 2. Lowest route metric
|
||||
// 3. Lowest interface metric
|
||||
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
|
||||
var skipInterfaceIndex int
|
||||
if vpnIntf != "" {
|
||||
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
|
||||
skipInterfaceIndex = iface.Index
|
||||
} else {
|
||||
// not critical, if we cannot get ahold of the interface then we won't need to skip it
|
||||
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
|
||||
}
|
||||
}
|
||||
|
||||
table, err := getWindowsRoutingTable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routing table: %w", err)
|
||||
}
|
||||
defer freeWindowsRoutingTable(table)
|
||||
|
||||
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("no route to %s", dest)
|
||||
}
|
||||
|
||||
// Sort routes: prefix length -> route metric -> interface metric
|
||||
sortRouteCandidates(candidates)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
log.Debugf("interface %s is down, trying next route", iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
|
||||
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no usable interface found for %s", dest)
|
||||
}
|
||||
|
||||
// formatRouteAge formats the route age in seconds to a human-readable string
|
||||
func formatRouteAge(ageSeconds uint32) string {
|
||||
if ageSeconds == 0 {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
||||
if !ok {
|
||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||
return prefix, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// Dial connects to the address on the named network.
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ListenPacket listens for incoming packets on the given network and address.
|
||||
|
||||
@@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
||||
if netstack.IsEnabled() {
|
||||
n.iFaceDiscover = pionDiscover{}
|
||||
} else {
|
||||
newMobileIFaceDiscover(iFaceDiscover)
|
||||
n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
|
||||
}
|
||||
return n, n.UpdateInterfaces()
|
||||
}
|
||||
|
||||
98
client/internal/wg_iface_monitor.go
Normal file
98
client/internal/wg_iface_monitor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||
// if the interface is deleted externally while the engine is running.
|
||||
type WGIfaceMonitor struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
|
||||
func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||
return &WGIfaceMonitor{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins monitoring the WireGuard interface.
|
||||
// It relies on the provided context cancellation to stop.
|
||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||
defer close(m.done)
|
||||
|
||||
// Skip on mobile platforms as they handle interface lifecycle differently
|
||||
if runtime.GOOS == "android" || runtime.GOOS == "ios" {
|
||||
log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
|
||||
return false, errors.New("not supported on mobile platforms")
|
||||
}
|
||||
|
||||
if ifaceName == "" {
|
||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||
return false, errors.New("empty interface name")
|
||||
}
|
||||
|
||||
// Get initial interface index to track the specific interface instance
|
||||
expectedIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
|
||||
return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||
case <-ticker.C:
|
||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||
if err != nil {
|
||||
// Interface was deleted
|
||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
// Check if interface index changed (interface was recreated)
|
||||
if currentIndex != expectedIndex {
|
||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||
ifaceName, expectedIndex, currentIndex)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// getInterfaceIndex returns the index of a network interface by name.
|
||||
// Returns an error if the interface is not found.
|
||||
func getInterfaceIndex(name string) (int, error) {
|
||||
if name == "" {
|
||||
return 0, fmt.Errorf("empty interface name")
|
||||
}
|
||||
ifi, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
// Check if it's specifically a "not found" error
|
||||
if errors.Is(err, &net.OpError{}) {
|
||||
// On some systems, this might be a "not found" error
|
||||
return 0, fmt.Errorf("interface not found: %w", err)
|
||||
}
|
||||
return 0, fmt.Errorf("failed to lookup interface: %w", err)
|
||||
}
|
||||
if ifi == nil {
|
||||
return 0, fmt.Errorf("interface not found")
|
||||
}
|
||||
return ifi.Index, nil
|
||||
}
|
||||
49
client/net/conn.go
Normal file
49
client/net/conn.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *Conn) Close() error {
|
||||
return closeConn(c.ID, c.Conn)
|
||||
}
|
||||
|
||||
// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
|
||||
type TCPConn struct {
|
||||
*net.TCPConn
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *TCPConn) Close() error {
|
||||
return closeConn(c.ID, c.TCPConn)
|
||||
}
|
||||
|
||||
// closeConn is a helper function to close connections and execute close hooks.
|
||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||
err := conn.Close()
|
||||
|
||||
closeHooks := hooks.GetCloseHooks()
|
||||
for _, hook := range closeHooks {
|
||||
if err := hook(id); err != nil {
|
||||
log.Errorf("Error executing close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
82
client/net/dial.go
Normal file
82
client/net/dial.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
switch c := conn.(type) {
|
||||
case *net.UDPConn:
|
||||
// Advanced routing: plain connection
|
||||
return c, nil
|
||||
case *Conn:
|
||||
// Legacy routing: wrapped connection preserves close hooks
|
||||
udpConn, ok := c.Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
|
||||
}
|
||||
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
switch c := conn.(type) {
|
||||
case *net.TCPConn:
|
||||
// Advanced routing: plain connection
|
||||
return c, nil
|
||||
case *Conn:
|
||||
// Legacy routing: wrapped connection preserves close hooks
|
||||
tcpConn, ok := c.Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
|
||||
}
|
||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||
}
|
||||
@@ -16,6 +16,5 @@ func NewDialer() *Dialer {
|
||||
Dialer: &net.Dialer{},
|
||||
}
|
||||
dialer.init()
|
||||
|
||||
return dialer
|
||||
}
|
||||
87
client/net/dialer_dial.go
Normal file
87
client/net/dialer_dial.go
Normal file
@@ -0,0 +1,87 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Dialing %s %s", network, address)
|
||||
|
||||
if CustomRoutingDisabled() || AdvancedRouting() {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
connID := hooks.GenerateConnID()
|
||||
if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil {
|
||||
log.Errorf("Failed to call dialer hooks: %v", err)
|
||||
}
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
return &Conn{Conn: conn, ID: connID}, nil
|
||||
}
|
||||
|
||||
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
writeHooks := hooks.GetWriteHooks()
|
||||
if len(writeHooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("split host and port: %w", err)
|
||||
}
|
||||
|
||||
resolver := customResolver
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver
|
||||
}
|
||||
|
||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||
}
|
||||
|
||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip.IP)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err))
|
||||
continue
|
||||
}
|
||||
for _, hook := range writeHooks {
|
||||
if err := hook(connID, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
7
client/net/dialer_init_generic.go
Normal file
7
client/net/dialer_init_generic.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
// implemented on Linux, Android, and Windows only
|
||||
}
|
||||
5
client/net/dialer_init_windows.go
Normal file
5
client/net/dialer_init_windows.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = applyUnicastIFToSocket
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
const (
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
||||
)
|
||||
|
||||
// CustomRoutingDisabled returns true if custom routing is disabled.
|
||||
24
client/net/env_android.go
Normal file
24
client/net/env_android.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment for Android
|
||||
func Init() {
|
||||
// No initialization needed on Android
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||
// Always returns true on Android since we cannot handle routes dynamically.
|
||||
func AdvancedRouting() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on Android
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on Android - not needed for Android VPN service
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on Android
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
23
client/net/env_generic.go
Normal file
23
client/net/env_generic.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build !linux && !windows && !android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment (no-op on non-Linux/Windows platforms)
|
||||
func Init() {
|
||||
// No-op on non-Linux/Windows platforms
|
||||
}
|
||||
|
||||
// AdvancedRouting returns false on non-Linux/Windows platforms
|
||||
func AdvancedRouting() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on non-Windows platforms
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on non-Windows platforms
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on non-Windows platforms
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
@@ -17,8 +17,7 @@ import (
|
||||
|
||||
const (
|
||||
// these have the same effect, skip socket env supported for backward compatibility
|
||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
)
|
||||
|
||||
var advancedRoutingSupported bool
|
||||
@@ -27,6 +26,7 @@ func Init() {
|
||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||
func AdvancedRouting() bool {
|
||||
return advancedRoutingSupported
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool {
|
||||
}
|
||||
|
||||
func CheckFwmarkSupport() bool {
|
||||
// temporarily enable advanced routing to check fwmarks are supported
|
||||
// temporarily enable advanced routing to check if fwmarks are supported
|
||||
old := advancedRoutingSupported
|
||||
advancedRoutingSupported = true
|
||||
defer func() {
|
||||
@@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on Linux
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on Linux - not needed for fwmark-based routing
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on Linux
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
67
client/net/env_windows.go
Normal file
67
client/net/env_windows.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build windows
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
var (
|
||||
vpnInterfaceName string
|
||||
vpnInitMutex sync.RWMutex
|
||||
|
||||
advancedRoutingSupported bool
|
||||
)
|
||||
|
||||
func Init() {
|
||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||
}
|
||||
|
||||
func checkAdvancedRoutingSupport() bool {
|
||||
var err error
|
||||
var legacyRouting bool
|
||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||
legacyRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||
}
|
||||
}
|
||||
|
||||
if legacyRouting || netstack.IsEnabled() {
|
||||
log.Info("advanced routing has been requested to be disabled")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Info("system supports advanced routing")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||
func AdvancedRouting() bool {
|
||||
return advancedRoutingSupported
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns the stored VPN interface name
|
||||
func GetVPNInterfaceName() string {
|
||||
vpnInitMutex.RLock()
|
||||
defer vpnInitMutex.RUnlock()
|
||||
return vpnInterfaceName
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName sets the VPN interface name for lazy initialization
|
||||
func SetVPNInterfaceName(name string) {
|
||||
vpnInitMutex.Lock()
|
||||
defer vpnInitMutex.Unlock()
|
||||
vpnInterfaceName = name
|
||||
|
||||
if name != "" {
|
||||
log.Infof("VPN interface name set to %s for route exclusion", name)
|
||||
}
|
||||
}
|
||||
93
client/net/hooks/hooks.go
Normal file
93
client/net/hooks/hooks.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||
type ConnectionID string
|
||||
|
||||
// GenerateConnID generates a unique identifier for each connection.
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
type CloseHookFunc func(connID ConnectionID) error
|
||||
type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
|
||||
var (
|
||||
hooksMutex sync.RWMutex
|
||||
|
||||
writeHooks []WriteHookFunc
|
||||
closeHooks []CloseHookFunc
|
||||
addressRemoveHooks []AddressRemoveHookFunc
|
||||
)
|
||||
|
||||
// AddWriteHook allows adding a new hook to be executed before writing/dialing.
|
||||
func AddWriteHook(hook WriteHookFunc) {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
writeHooks = append(writeHooks, hook)
|
||||
}
|
||||
|
||||
// AddCloseHook allows adding a new hook to be executed on connection close.
|
||||
func AddCloseHook(hook CloseHookFunc) {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
closeHooks = append(closeHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveWriteHooks removes all write hooks.
|
||||
func RemoveWriteHooks() {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
writeHooks = nil
|
||||
}
|
||||
|
||||
// RemoveCloseHooks removes all close hooks.
|
||||
func RemoveCloseHooks() {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
closeHooks = nil
|
||||
}
|
||||
|
||||
// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||
func AddAddressRemoveHook(hook AddressRemoveHookFunc) {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
addressRemoveHooks = append(addressRemoveHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveAddressRemoveHooks removes all listener address hooks.
|
||||
func RemoveAddressRemoveHooks() {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
addressRemoveHooks = nil
|
||||
}
|
||||
|
||||
// GetWriteHooks returns a copy of the current write hooks.
|
||||
func GetWriteHooks() []WriteHookFunc {
|
||||
hooksMutex.RLock()
|
||||
defer hooksMutex.RUnlock()
|
||||
return slices.Clone(writeHooks)
|
||||
}
|
||||
|
||||
// GetCloseHooks returns a copy of the current close hooks.
|
||||
func GetCloseHooks() []CloseHookFunc {
|
||||
hooksMutex.RLock()
|
||||
defer hooksMutex.RUnlock()
|
||||
return slices.Clone(closeHooks)
|
||||
}
|
||||
|
||||
// GetAddressRemoveHooks returns a copy of the current listener address remove hooks.
|
||||
func GetAddressRemoveHooks() []AddressRemoveHookFunc {
|
||||
hooksMutex.RLock()
|
||||
defer hooksMutex.RUnlock()
|
||||
return slices.Clone(addressRemoveHooks)
|
||||
}
|
||||
47
client/net/listen.go
Normal file
47
client/net/listen.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
}
|
||||
|
||||
switch c := conn.(type) {
|
||||
case *net.UDPConn:
|
||||
// Advanced routing: plain connection
|
||||
return c, nil
|
||||
case *PacketConn:
|
||||
// Legacy routing: wrapped connection for hooks
|
||||
udpConn, ok := c.PacketConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := c.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn)
|
||||
}
|
||||
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||
}
|
||||
@@ -7,14 +7,12 @@ import (
|
||||
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
||||
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
||||
type ListenerConfig struct {
|
||||
*net.ListenConfig
|
||||
net.ListenConfig
|
||||
}
|
||||
|
||||
// NewListener creates a new ListenerConfig instance.
|
||||
func NewListener() *ListenerConfig {
|
||||
listener := &ListenerConfig{
|
||||
ListenConfig: &net.ListenConfig{},
|
||||
}
|
||||
listener := &ListenerConfig{}
|
||||
listener.init()
|
||||
|
||||
return listener
|
||||
7
client/net/listener_init_generic.go
Normal file
7
client/net/listener_init_generic.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// implemented on Linux, Android, and Windows only
|
||||
}
|
||||
8
client/net/listener_init_windows.go
Normal file
8
client/net/listener_init_windows.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses.
|
||||
// For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case
|
||||
// the interface will be selected that serves the default route.
|
||||
l.ListenConfig.Control = applyUnicastIFToSocket
|
||||
}
|
||||
153
client/net/listener_listen.go
Normal file
153
client/net/listener_listen.go
Normal file
@@ -0,0 +1,153 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
// ListenPacket listens on the network address and returns a PacketConn
|
||||
// which includes support for write hooks.
|
||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||
if CustomRoutingDisabled() || AdvancedRouting() {
|
||||
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
}
|
||||
|
||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen packet: %w", err)
|
||||
}
|
||||
connID := hooks.GenerateConnID()
|
||||
|
||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
ID hooks.ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
|
||||
log.Errorf("Failed to call write hooks: %v", err)
|
||||
}
|
||||
return c.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *PacketConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.PacketConn)
|
||||
}
|
||||
|
||||
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type UDPConn struct {
|
||||
*net.UDPConn
|
||||
ID hooks.ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
|
||||
log.Errorf("Failed to call write hooks: %v", err)
|
||||
}
|
||||
return c.UDPConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *UDPConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
}
|
||||
|
||||
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||
func (c *PacketConn) RemoveAddress(addr string) {
|
||||
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||
return
|
||||
}
|
||||
|
||||
ipStr, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen())
|
||||
|
||||
addressRemoveHooks := hooks.GetAddressRemoveHooks()
|
||||
if len(addressRemoveHooks) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, hook := range addressRemoveHooks {
|
||||
if err := hook(c.ID, prefix); err != nil {
|
||||
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality
|
||||
func WrapPacketConn(conn net.PacketConn) net.PacketConn {
|
||||
if AdvancedRouting() {
|
||||
// hooks not required for advanced routing
|
||||
return conn
|
||||
}
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
ID: hooks.GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error {
|
||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded {
|
||||
return nil
|
||||
}
|
||||
|
||||
writeHooks := hooks.GetWriteHooks()
|
||||
if len(writeHooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr)
|
||||
}
|
||||
|
||||
prefix, err := util.GetPrefixFromIP(udpAddr.IP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err)
|
||||
}
|
||||
|
||||
log.Debugf("Listener resolved IP for %s: %s", addr, prefix)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, hook := range writeHooks {
|
||||
if err := hook(id, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool {
|
||||
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
|
||||
}
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||
type ConnectionID string
|
||||
|
||||
type AddHookFunc func(connID ConnectionID, IP net.IP) error
|
||||
type RemoveHookFunc func(connID ConnectionID) error
|
||||
|
||||
// GenerateConnID generates a unique identifier for each connection.
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
||||
var endIP net.IP
|
||||
addr := network.Addr().AsSlice()
|
||||
284
client/net/net_windows.go
Normal file
284
client/net/net_windows.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
|
||||
IpUnicastIf = 31
|
||||
Ipv6UnicastIf = 31
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options
|
||||
Ipv6V6only = 27
|
||||
)
|
||||
|
||||
// GetBestInterfaceFunc is set at runtime to avoid import cycle
|
||||
var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
|
||||
|
||||
// nativeToBigEndian converts a uint32 from native byte order to big-endian
|
||||
func nativeToBigEndian(v uint32) uint32 {
|
||||
return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24
|
||||
}
|
||||
|
||||
// parseDestinationAddress parses the destination address from various formats
|
||||
func parseDestinationAddress(network, address string) (netip.Addr, error) {
|
||||
if address == "" {
|
||||
if strings.HasSuffix(network, "6") {
|
||||
return netip.IPv6Unspecified(), nil
|
||||
}
|
||||
return netip.IPv4Unspecified(), nil
|
||||
}
|
||||
|
||||
if addrPort, err := netip.ParseAddrPort(address); err == nil {
|
||||
return addrPort.Addr(), nil
|
||||
}
|
||||
|
||||
if dest, err := netip.ParseAddr(address); err == nil {
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// No port, treat whole string as host
|
||||
host = address
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
if strings.HasSuffix(network, "6") {
|
||||
return netip.IPv6Unspecified(), nil
|
||||
}
|
||||
return netip.IPv4Unspecified(), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
|
||||
}
|
||||
|
||||
dest, ok := netip.AddrFromSlice(ips[0].IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
|
||||
}
|
||||
|
||||
if ips[0].Zone != "" {
|
||||
dest = dest.WithZone(ips[0].Zone)
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func getInterfaceFromZone(zone string) *net.Interface {
|
||||
if zone == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(zone)
|
||||
if err != nil {
|
||||
log.Debugf("invalid zone format for Windows (expected numeric): %s", zone)
|
||||
return nil
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByIndex(idx)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get interface by index %d from zone: %v", idx, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return iface
|
||||
}
|
||||
|
||||
type interfaceSelection struct {
|
||||
iface4 *net.Interface
|
||||
iface6 *net.Interface
|
||||
}
|
||||
|
||||
func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection {
|
||||
iface := getInterfaceFromZone(zone)
|
||||
if iface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dest.Is6() {
|
||||
return &interfaceSelection{iface6: iface}
|
||||
}
|
||||
return &interfaceSelection{iface4: iface}
|
||||
}
|
||||
|
||||
func selectInterfaceForUnspecified() (*interfaceSelection, error) {
|
||||
if GetBestInterfaceFunc == nil {
|
||||
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||
}
|
||||
|
||||
var result interfaceSelection
|
||||
vpnIfaceName := GetVPNInterfaceName()
|
||||
|
||||
if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil {
|
||||
result.iface4 = iface4
|
||||
} else {
|
||||
log.Debugf("No IPv4 default route found: %v", err)
|
||||
}
|
||||
|
||||
if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil {
|
||||
result.iface6 = iface6
|
||||
} else {
|
||||
log.Debugf("No IPv6 default route found: %v", err)
|
||||
}
|
||||
|
||||
if result.iface4 == nil && result.iface6 == nil {
|
||||
return nil, errors.New("no default routes found")
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
|
||||
if zone := dest.Zone(); zone != "" {
|
||||
if selection := selectInterfaceForZone(dest, zone); selection != nil {
|
||||
return selection, nil
|
||||
}
|
||||
}
|
||||
|
||||
if dest.IsUnspecified() {
|
||||
return selectInterfaceForUnspecified()
|
||||
}
|
||||
|
||||
if GetBestInterfaceFunc == nil {
|
||||
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||
}
|
||||
|
||||
iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find route for %s: %w", dest, err)
|
||||
}
|
||||
|
||||
if dest.Is6() {
|
||||
return &interfaceSelection{iface6: iface}, nil
|
||||
}
|
||||
return &interfaceSelection{iface4: iface}, nil
|
||||
}
|
||||
|
||||
func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error {
|
||||
ifaceIndexBE := nativeToBigEndian(uint32(iface.Index))
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil {
|
||||
return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
// The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.)
|
||||
// Never generic ones (udp, tcp, ip)
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(network, "4"):
|
||||
// IPv4-only socket (udp4, tcp4, ip4)
|
||||
return setUnicastIfIPv4(fd, network, selection, address)
|
||||
|
||||
case strings.HasSuffix(network, "6"):
|
||||
// IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only
|
||||
return setUnicastIfIPv6(fd, network, selection, address)
|
||||
}
|
||||
|
||||
// Shouldn't reach here based on Go's documented behavior
|
||||
return fmt.Errorf("unexpected network type: %s", network)
|
||||
}
|
||||
|
||||
func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
if selection.iface4 == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
isDualStack := checkDualStack(fd)
|
||||
|
||||
// For dual-stack sockets, also set the IPv4 option
|
||||
if isDualStack && selection.iface4 != nil {
|
||||
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||
}
|
||||
|
||||
if selection.iface6 == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := setIPv6UnicastIF(fd, selection.iface6); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDualStack(fd uintptr) bool {
|
||||
var v6Only int
|
||||
v6OnlyLen := int32(unsafe.Sizeof(v6Only))
|
||||
err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen)
|
||||
return err == nil && v6Only == 0
|
||||
}
|
||||
|
||||
// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address
|
||||
func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
dest, err := parseDestinationAddress(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dest = dest.Unmap()
|
||||
|
||||
if !dest.IsValid() {
|
||||
return fmt.Errorf("invalid destination address for %s", address)
|
||||
}
|
||||
|
||||
selection, err := selectInterface(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var controlErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
controlErr = setUnicastIf(fd, network, selection, address)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
|
||||
return controlErr
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user