mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
Compare commits
1 Commits
v0.58.1
...
snyk-fix-9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3fc6ab5e3 |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.23"
|
||||
SIGN_PIPE_VER: "v0.0.22"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
@@ -53,7 +52,7 @@
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
||||
|
||||
### NetBird on Lawrence Systems (Video)
|
||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||
|
||||
@@ -18,7 +18,7 @@ ENV \
|
||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
||||
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
@@ -33,7 +33,6 @@ type ErrListener interface {
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
// Auth can register or login new client
|
||||
@@ -182,11 +181,6 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
|
||||
if err == nil {
|
||||
go urlOpener.OnLoginSuccess()
|
||||
}
|
||||
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
|
||||
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||
|
||||
// DialClientGRPCServer returns client connection to the daemon server.
|
||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
||||
defer cancel()
|
||||
|
||||
return grpc.DialContext(
|
||||
|
||||
@@ -230,9 +230,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{
|
||||
WaitForReady: func() *bool { b := true; return &b }(),
|
||||
})
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{})
|
||||
run := make(chan struct{}, 1)
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -3,7 +3,7 @@ package bind
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -409,7 +409,7 @@ func toBytes(s string) (int64, error) {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunKernelDevice struct {
|
||||
@@ -101,8 +101,13 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var udpConn net.PacketConn = rawSock
|
||||
if !nbnet.AdvancedRouting() {
|
||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
||||
}
|
||||
|
||||
bindParams := udpmux.UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||
UDPConn: udpConn,
|
||||
Net: t.transportNet,
|
||||
FilterFn: t.filterFn,
|
||||
WGAddress: t.address,
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -34,7 +34,7 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ServiceViaMemory struct {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
|
||||
@@ -446,8 +446,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// if inbound conns are blocked there is no need to create the ACL manager
|
||||
if e.firewall != nil && !e.config.BlockInbound {
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
|
||||
@@ -36,9 +36,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -108,10 +108,6 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||
|
||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||
}
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
@@ -212,7 +208,7 @@ func (m *DefaultManager) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
@@ -223,7 +219,7 @@ func (m *DefaultManager) Init() error {
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
@@ -289,15 +285,11 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
nbnet.SetVPNInterfaceName("")
|
||||
}
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const localSubnetsCacheTTL = 15 * time.Minute
|
||||
@@ -95,9 +96,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hooks.RemoveWriteHooks()
|
||||
hooks.RemoveCloseHooks()
|
||||
hooks.RemoveAddressRemoveHooks()
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := r.refCounter.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
@@ -289,7 +290,12 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
@@ -298,7 +304,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID hooks.ConnectionID) error {
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
@@ -311,20 +317,36 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
|
||||
continue
|
||||
}
|
||||
if err := beforeHook("init", prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||
}
|
||||
}
|
||||
|
||||
hooks.AddWriteHook(beforeHook)
|
||||
hooks.AddCloseHook(afterHook)
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||
var merr *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
@@ -144,11 +143,10 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||
@@ -343,11 +341,10 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
@@ -487,11 +484,10 @@ func setupTestEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
advancedRouting := nbnet.AdvancedRouting()
|
||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// IPRule contains IP rule information for debugging
|
||||
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
|
||||
if !advancedRouting {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if !advancedRouting {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -20,16 +19,9 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
nbnet.GetBestInterfaceFunc = GetBestInterface
|
||||
}
|
||||
|
||||
const (
|
||||
InfiniteLifetime = 0xffffffff
|
||||
)
|
||||
const InfiniteLifetime = 0xffffffff
|
||||
|
||||
type RouteUpdateType int
|
||||
|
||||
@@ -85,14 +77,6 @@ type MIB_IPFORWARD_TABLE2 struct {
|
||||
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
|
||||
}
|
||||
|
||||
// candidateRoute represents a potential route for selection during route lookup
|
||||
type candidateRoute struct {
|
||||
interfaceIndex uint32
|
||||
prefixLength uint8
|
||||
routeMetric uint32
|
||||
interfaceMetric int
|
||||
}
|
||||
|
||||
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
|
||||
type IP_ADDRESS_PREFIX struct {
|
||||
Prefix SOCKADDR_INET
|
||||
@@ -193,20 +177,11 @@ const (
|
||||
RouteDeleted
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Using legacy routing setup with ref counters")
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
@@ -660,7 +635,10 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
|
||||
|
||||
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
|
||||
if table != nil {
|
||||
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
||||
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
||||
if ret != 0 {
|
||||
log.Warnf("FreeMibTable failed with return code: %d", ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -674,7 +652,8 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
|
||||
entryPtr := basePtr + uintptr(i)*entrySize
|
||||
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||
|
||||
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
|
||||
detailed := buildWindowsDetailedRoute(entry)
|
||||
if detailed != nil {
|
||||
detailedRoutes = append(detailedRoutes, *detailed)
|
||||
}
|
||||
}
|
||||
@@ -823,46 +802,6 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
|
||||
return ip
|
||||
}
|
||||
|
||||
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
|
||||
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
|
||||
var candidates []candidateRoute
|
||||
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
|
||||
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
|
||||
|
||||
for i := uint32(0); i < table.NumEntries; i++ {
|
||||
entryPtr := basePtr + uintptr(i)*entrySize
|
||||
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||
|
||||
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
|
||||
candidates = append(candidates, *candidate)
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
|
||||
// Returns nil if the route doesn't match the destination or should be skipped
|
||||
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
|
||||
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
|
||||
return nil
|
||||
}
|
||||
|
||||
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
|
||||
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
|
||||
return nil
|
||||
}
|
||||
|
||||
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
|
||||
|
||||
return &candidateRoute{
|
||||
interfaceIndex: entry.InterfaceIndex,
|
||||
prefixLength: entry.DestinationPrefix.PrefixLength,
|
||||
routeMetric: entry.Metric,
|
||||
interfaceMetric: interfaceMetric,
|
||||
}
|
||||
}
|
||||
|
||||
// getInterfaceMetric retrieves the interface metric for a given interface and address family
|
||||
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
||||
if interfaceIndex == 0 {
|
||||
@@ -882,76 +821,6 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
||||
return int(ipInterfaceRow.Metric)
|
||||
}
|
||||
|
||||
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
|
||||
func sortRouteCandidates(candidates []candidateRoute) {
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
if candidates[i].prefixLength != candidates[j].prefixLength {
|
||||
return candidates[i].prefixLength > candidates[j].prefixLength
|
||||
}
|
||||
if candidates[i].routeMetric != candidates[j].routeMetric {
|
||||
return candidates[i].routeMetric < candidates[j].routeMetric
|
||||
}
|
||||
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
|
||||
})
|
||||
}
|
||||
|
||||
// GetBestInterface finds the best interface for reaching a destination,
|
||||
// excluding the VPN interface to avoid routing loops.
|
||||
//
|
||||
// Route selection priority:
|
||||
// 1. Longest prefix match (most specific route)
|
||||
// 2. Lowest route metric
|
||||
// 3. Lowest interface metric
|
||||
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
|
||||
var skipInterfaceIndex int
|
||||
if vpnIntf != "" {
|
||||
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
|
||||
skipInterfaceIndex = iface.Index
|
||||
} else {
|
||||
// not critical, if we cannot get ahold of the interface then we won't need to skip it
|
||||
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
|
||||
}
|
||||
}
|
||||
|
||||
table, err := getWindowsRoutingTable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routing table: %w", err)
|
||||
}
|
||||
defer freeWindowsRoutingTable(table)
|
||||
|
||||
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("no route to %s", dest)
|
||||
}
|
||||
|
||||
// Sort routes: prefix length -> route metric -> interface metric
|
||||
sortRouteCandidates(candidates)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
log.Debugf("interface %s is down, trying next route", iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
|
||||
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no usable interface found for %s", dest)
|
||||
}
|
||||
|
||||
// formatRouteAge formats the route age in seconds to a human-readable string
|
||||
func formatRouteAge(ageSeconds uint32) string {
|
||||
if ageSeconds == 0 {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -12,8 +12,18 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
||||
if !ok {
|
||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
|
||||
addr = addr.Unmap()
|
||||
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return prefix, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// Dial connects to the address on the named network.
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ListenPacket listens for incoming packets on the given network and address.
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *Conn) Close() error {
|
||||
return closeConn(c.ID, c.Conn)
|
||||
}
|
||||
|
||||
// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
|
||||
type TCPConn struct {
|
||||
*net.TCPConn
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *TCPConn) Close() error {
|
||||
return closeConn(c.ID, c.TCPConn)
|
||||
}
|
||||
|
||||
// closeConn is a helper function to close connections and execute close hooks.
|
||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||
err := conn.Close()
|
||||
|
||||
closeHooks := hooks.GetCloseHooks()
|
||||
for _, hook := range closeHooks {
|
||||
if err := hook(id); err != nil {
|
||||
log.Errorf("Error executing close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
switch c := conn.(type) {
|
||||
case *net.UDPConn:
|
||||
// Advanced routing: plain connection
|
||||
return c, nil
|
||||
case *Conn:
|
||||
// Legacy routing: wrapped connection preserves close hooks
|
||||
udpConn, ok := c.Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
|
||||
}
|
||||
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
switch c := conn.(type) {
|
||||
case *net.TCPConn:
|
||||
// Advanced routing: plain connection
|
||||
return c, nil
|
||||
case *Conn:
|
||||
// Legacy routing: wrapped connection preserves close hooks
|
||||
tcpConn, ok := c.Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
|
||||
}
|
||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Dialing %s %s", network, address)
|
||||
|
||||
if CustomRoutingDisabled() || AdvancedRouting() {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
connID := hooks.GenerateConnID()
|
||||
if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil {
|
||||
log.Errorf("Failed to call dialer hooks: %v", err)
|
||||
}
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
return &Conn{Conn: conn, ID: connID}, nil
|
||||
}
|
||||
|
||||
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
writeHooks := hooks.GetWriteHooks()
|
||||
if len(writeHooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("split host and port: %w", err)
|
||||
}
|
||||
|
||||
resolver := customResolver
|
||||
if resolver == nil {
|
||||
resolver = net.DefaultResolver
|
||||
}
|
||||
|
||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||
}
|
||||
|
||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip.IP)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err))
|
||||
continue
|
||||
}
|
||||
for _, hook := range writeHooks {
|
||||
if err := hook(connID, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
// implemented on Linux, Android, and Windows only
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = applyUnicastIFToSocket
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment for Android
|
||||
func Init() {
|
||||
// No initialization needed on Android
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||
// Always returns true on Android since we cannot handle routes dynamically.
|
||||
func AdvancedRouting() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on Android
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on Android - not needed for Android VPN service
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on Android
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build !linux && !windows && !android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment (no-op on non-Linux/Windows platforms)
|
||||
func Init() {
|
||||
// No-op on non-Linux/Windows platforms
|
||||
}
|
||||
|
||||
// AdvancedRouting returns false on non-Linux/Windows platforms
|
||||
func AdvancedRouting() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on non-Windows platforms
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on non-Windows platforms
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on non-Windows platforms
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
var (
|
||||
vpnInterfaceName string
|
||||
vpnInitMutex sync.RWMutex
|
||||
|
||||
advancedRoutingSupported bool
|
||||
)
|
||||
|
||||
func Init() {
|
||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||
}
|
||||
|
||||
func checkAdvancedRoutingSupport() bool {
|
||||
var err error
|
||||
var legacyRouting bool
|
||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||
legacyRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||
}
|
||||
}
|
||||
|
||||
if legacyRouting || netstack.IsEnabled() {
|
||||
log.Info("advanced routing has been requested to be disabled")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Info("system supports advanced routing")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||
func AdvancedRouting() bool {
|
||||
return advancedRoutingSupported
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns the stored VPN interface name
|
||||
func GetVPNInterfaceName() string {
|
||||
vpnInitMutex.RLock()
|
||||
defer vpnInitMutex.RUnlock()
|
||||
return vpnInterfaceName
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName sets the VPN interface name for lazy initialization
|
||||
func SetVPNInterfaceName(name string) {
|
||||
vpnInitMutex.Lock()
|
||||
defer vpnInitMutex.Unlock()
|
||||
vpnInterfaceName = name
|
||||
|
||||
if name != "" {
|
||||
log.Infof("VPN interface name set to %s for route exclusion", name)
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||
type ConnectionID string
|
||||
|
||||
// GenerateConnID generates a unique identifier for each connection.
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
type CloseHookFunc func(connID ConnectionID) error
|
||||
type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
|
||||
var (
|
||||
hooksMutex sync.RWMutex
|
||||
|
||||
writeHooks []WriteHookFunc
|
||||
closeHooks []CloseHookFunc
|
||||
addressRemoveHooks []AddressRemoveHookFunc
|
||||
)
|
||||
|
||||
// AddWriteHook allows adding a new hook to be executed before writing/dialing.
|
||||
func AddWriteHook(hook WriteHookFunc) {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
writeHooks = append(writeHooks, hook)
|
||||
}
|
||||
|
||||
// AddCloseHook allows adding a new hook to be executed on connection close.
|
||||
func AddCloseHook(hook CloseHookFunc) {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
closeHooks = append(closeHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveWriteHooks removes all write hooks.
|
||||
func RemoveWriteHooks() {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
writeHooks = nil
|
||||
}
|
||||
|
||||
// RemoveCloseHooks removes all close hooks.
|
||||
func RemoveCloseHooks() {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
closeHooks = nil
|
||||
}
|
||||
|
||||
// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||
func AddAddressRemoveHook(hook AddressRemoveHookFunc) {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
addressRemoveHooks = append(addressRemoveHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveAddressRemoveHooks removes all listener address hooks.
|
||||
func RemoveAddressRemoveHooks() {
|
||||
hooksMutex.Lock()
|
||||
defer hooksMutex.Unlock()
|
||||
addressRemoveHooks = nil
|
||||
}
|
||||
|
||||
// GetWriteHooks returns a copy of the current write hooks.
|
||||
func GetWriteHooks() []WriteHookFunc {
|
||||
hooksMutex.RLock()
|
||||
defer hooksMutex.RUnlock()
|
||||
return slices.Clone(writeHooks)
|
||||
}
|
||||
|
||||
// GetCloseHooks returns a copy of the current close hooks.
|
||||
func GetCloseHooks() []CloseHookFunc {
|
||||
hooksMutex.RLock()
|
||||
defer hooksMutex.RUnlock()
|
||||
return slices.Clone(closeHooks)
|
||||
}
|
||||
|
||||
// GetAddressRemoveHooks returns a copy of the current listener address remove hooks.
|
||||
func GetAddressRemoveHooks() []AddressRemoveHookFunc {
|
||||
hooksMutex.RLock()
|
||||
defer hooksMutex.RUnlock()
|
||||
return slices.Clone(addressRemoveHooks)
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
}
|
||||
|
||||
switch c := conn.(type) {
|
||||
case *net.UDPConn:
|
||||
// Advanced routing: plain connection
|
||||
return c, nil
|
||||
case *PacketConn:
|
||||
// Legacy routing: wrapped connection for hooks
|
||||
udpConn, ok := c.PacketConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := c.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn)
|
||||
}
|
||||
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// implemented on Linux, Android, and Windows only
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses.
|
||||
// For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case
|
||||
// the interface will be selected that serves the default route.
|
||||
l.ListenConfig.Control = applyUnicastIFToSocket
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
// ListenPacket listens on the network address and returns a PacketConn
|
||||
// which includes support for write hooks.
|
||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||
if CustomRoutingDisabled() || AdvancedRouting() {
|
||||
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
}
|
||||
|
||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen packet: %w", err)
|
||||
}
|
||||
connID := hooks.GenerateConnID()
|
||||
|
||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
ID hooks.ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
|
||||
log.Errorf("Failed to call write hooks: %v", err)
|
||||
}
|
||||
return c.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *PacketConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.PacketConn)
|
||||
}
|
||||
|
||||
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type UDPConn struct {
|
||||
*net.UDPConn
|
||||
ID hooks.ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
|
||||
log.Errorf("Failed to call write hooks: %v", err)
|
||||
}
|
||||
return c.UDPConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *UDPConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
}
|
||||
|
||||
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||
func (c *PacketConn) RemoveAddress(addr string) {
|
||||
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||
return
|
||||
}
|
||||
|
||||
ipStr, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen())
|
||||
|
||||
addressRemoveHooks := hooks.GetAddressRemoveHooks()
|
||||
if len(addressRemoveHooks) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, hook := range addressRemoveHooks {
|
||||
if err := hook(c.ID, prefix); err != nil {
|
||||
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality
|
||||
func WrapPacketConn(conn net.PacketConn) net.PacketConn {
|
||||
if AdvancedRouting() {
|
||||
// hooks not required for advanced routing
|
||||
return conn
|
||||
}
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
ID: hooks.GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error {
|
||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded {
|
||||
return nil
|
||||
}
|
||||
|
||||
writeHooks := hooks.GetWriteHooks()
|
||||
if len(writeHooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr)
|
||||
}
|
||||
|
||||
prefix, err := util.GetPrefixFromIP(udpAddr.IP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err)
|
||||
}
|
||||
|
||||
log.Debugf("Listener resolved IP for %s: %s", addr, prefix)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, hook := range writeHooks {
|
||||
if err := hook(id, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
|
||||
IpUnicastIf = 31
|
||||
Ipv6UnicastIf = 31
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options
|
||||
Ipv6V6only = 27
|
||||
)
|
||||
|
||||
// GetBestInterfaceFunc is set at runtime to avoid import cycle
|
||||
var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
|
||||
|
||||
// nativeToBigEndian converts a uint32 from native byte order to big-endian
|
||||
func nativeToBigEndian(v uint32) uint32 {
|
||||
return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24
|
||||
}
|
||||
|
||||
// parseDestinationAddress parses the destination address from various formats
|
||||
func parseDestinationAddress(network, address string) (netip.Addr, error) {
|
||||
if address == "" {
|
||||
if strings.HasSuffix(network, "6") {
|
||||
return netip.IPv6Unspecified(), nil
|
||||
}
|
||||
return netip.IPv4Unspecified(), nil
|
||||
}
|
||||
|
||||
if addrPort, err := netip.ParseAddrPort(address); err == nil {
|
||||
return addrPort.Addr(), nil
|
||||
}
|
||||
|
||||
if dest, err := netip.ParseAddr(address); err == nil {
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// No port, treat whole string as host
|
||||
host = address
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
if strings.HasSuffix(network, "6") {
|
||||
return netip.IPv6Unspecified(), nil
|
||||
}
|
||||
return netip.IPv4Unspecified(), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
|
||||
}
|
||||
|
||||
dest, ok := netip.AddrFromSlice(ips[0].IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
|
||||
}
|
||||
|
||||
if ips[0].Zone != "" {
|
||||
dest = dest.WithZone(ips[0].Zone)
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func getInterfaceFromZone(zone string) *net.Interface {
|
||||
if zone == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(zone)
|
||||
if err != nil {
|
||||
log.Debugf("invalid zone format for Windows (expected numeric): %s", zone)
|
||||
return nil
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByIndex(idx)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get interface by index %d from zone: %v", idx, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return iface
|
||||
}
|
||||
|
||||
type interfaceSelection struct {
|
||||
iface4 *net.Interface
|
||||
iface6 *net.Interface
|
||||
}
|
||||
|
||||
func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection {
|
||||
iface := getInterfaceFromZone(zone)
|
||||
if iface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dest.Is6() {
|
||||
return &interfaceSelection{iface6: iface}
|
||||
}
|
||||
return &interfaceSelection{iface4: iface}
|
||||
}
|
||||
|
||||
func selectInterfaceForUnspecified() (*interfaceSelection, error) {
|
||||
if GetBestInterfaceFunc == nil {
|
||||
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||
}
|
||||
|
||||
var result interfaceSelection
|
||||
vpnIfaceName := GetVPNInterfaceName()
|
||||
|
||||
if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil {
|
||||
result.iface4 = iface4
|
||||
} else {
|
||||
log.Debugf("No IPv4 default route found: %v", err)
|
||||
}
|
||||
|
||||
if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil {
|
||||
result.iface6 = iface6
|
||||
} else {
|
||||
log.Debugf("No IPv6 default route found: %v", err)
|
||||
}
|
||||
|
||||
if result.iface4 == nil && result.iface6 == nil {
|
||||
return nil, errors.New("no default routes found")
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
|
||||
if zone := dest.Zone(); zone != "" {
|
||||
if selection := selectInterfaceForZone(dest, zone); selection != nil {
|
||||
return selection, nil
|
||||
}
|
||||
}
|
||||
|
||||
if dest.IsUnspecified() {
|
||||
return selectInterfaceForUnspecified()
|
||||
}
|
||||
|
||||
if GetBestInterfaceFunc == nil {
|
||||
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||
}
|
||||
|
||||
iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find route for %s: %w", dest, err)
|
||||
}
|
||||
|
||||
if dest.Is6() {
|
||||
return &interfaceSelection{iface6: iface}, nil
|
||||
}
|
||||
return &interfaceSelection{iface4: iface}, nil
|
||||
}
|
||||
|
||||
func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error {
|
||||
ifaceIndexBE := nativeToBigEndian(uint32(iface.Index))
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil {
|
||||
return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error {
|
||||
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
// The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.)
|
||||
// Never generic ones (udp, tcp, ip)
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(network, "4"):
|
||||
// IPv4-only socket (udp4, tcp4, ip4)
|
||||
return setUnicastIfIPv4(fd, network, selection, address)
|
||||
|
||||
case strings.HasSuffix(network, "6"):
|
||||
// IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only
|
||||
return setUnicastIfIPv6(fd, network, selection, address)
|
||||
}
|
||||
|
||||
// Shouldn't reach here based on Go's documented behavior
|
||||
return fmt.Errorf("unexpected network type: %s", network)
|
||||
}
|
||||
|
||||
func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
if selection.iface4 == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
isDualStack := checkDualStack(fd)
|
||||
|
||||
// For dual-stack sockets, also set the IPv4 option
|
||||
if isDualStack && selection.iface4 != nil {
|
||||
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||
}
|
||||
|
||||
if selection.iface6 == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := setIPv6UnicastIF(fd, selection.iface6); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDualStack(fd uintptr) bool {
|
||||
var v6Only int
|
||||
v6OnlyLen := int32(unsafe.Sizeof(v6Only))
|
||||
err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen)
|
||||
return err == nil && v6Only == 0
|
||||
}
|
||||
|
||||
// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address
|
||||
func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
dest, err := parseDestinationAddress(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dest = dest.Unmap()
|
||||
|
||||
if !dest.IsValid() {
|
||||
return fmt.Errorf("invalid destination address for %s", address)
|
||||
}
|
||||
|
||||
selection, err := selectInterface(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var controlErr error
|
||||
err = c.Control(func(fd uintptr) {
|
||||
controlErr = setUnicastIf(fd, network, selection, address)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
|
||||
return controlErr
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
set -eEuo pipefail
|
||||
|
||||
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"}
|
||||
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"}
|
||||
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"}
|
||||
NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}"
|
||||
export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}"
|
||||
service_pids=()
|
||||
@@ -39,7 +39,7 @@ wait_for_message() {
|
||||
info "not waiting for log line ${message@Q} due to zero timeout."
|
||||
elif test -n "${log_file_path}"; then
|
||||
info "waiting for log line ${message@Q} for ${timeout} seconds..."
|
||||
grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
|
||||
grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
|
||||
else
|
||||
info "log file unsupported, sleeping for ${timeout} seconds..."
|
||||
sleep "${timeout}"
|
||||
@@ -81,7 +81,7 @@ wait_for_daemon_startup() {
|
||||
login_if_needed() {
|
||||
local timeout="${1}"
|
||||
|
||||
if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then
|
||||
if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then
|
||||
info "already logged in, skipping 'netbird up'..."
|
||||
else
|
||||
info "logging in..."
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc v6.32.1
|
||||
// protoc v5.29.3
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
@@ -794,10 +794,8 @@ type StatusRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
|
||||
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
|
||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||
WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StatusRequest) Reset() {
|
||||
@@ -844,13 +842,6 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *StatusRequest) GetWaitForReady() bool {
|
||||
if x != nil && x.WaitForReady != nil {
|
||||
return *x.WaitForReady
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// status of the server.
|
||||
@@ -4682,12 +4673,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_username\"\f\n" +
|
||||
"\n" +
|
||||
"UpResponse\"\xa1\x01\n" +
|
||||
"UpResponse\"g\n" +
|
||||
"\rStatusRequest\x12,\n" +
|
||||
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
|
||||
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
|
||||
"\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
|
||||
"\r_waitForReady\"\x82\x01\n" +
|
||||
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
|
||||
"\x0eStatusResponse\x12\x16\n" +
|
||||
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
|
||||
"\n" +
|
||||
@@ -5242,7 +5231,6 @@ func file_daemon_proto_init() {
|
||||
}
|
||||
file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[26].OneofWrappers = []any{
|
||||
(*PortInfo_Port)(nil),
|
||||
(*PortInfo_Range_)(nil),
|
||||
|
||||
@@ -186,8 +186,6 @@ message UpResponse {}
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||
optional bool waitForReady = 3;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
|
||||
@@ -67,7 +67,6 @@ type Server struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
clientRunning bool // protected by mutex
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{}
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
|
||||
@@ -107,10 +106,6 @@ func (s *Server) Start() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.clientRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
@@ -180,10 +175,12 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.clientRunning {
|
||||
return nil
|
||||
}
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.clientRunningChan = make(chan struct{}, 1)
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -214,7 +211,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
|
||||
defer func() {
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
@@ -264,10 +261,6 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}
|
||||
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
@@ -386,7 +379,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(callerCtx)
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
|
||||
md, ok := metadata.FromIncomingContext(callerCtx)
|
||||
if ok {
|
||||
@@ -396,11 +389,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
s.actCancel = cancel
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
|
||||
if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state := internal.CtxGetState(ctx)
|
||||
defer func() {
|
||||
status, err := state.Status()
|
||||
if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
|
||||
@@ -613,20 +606,6 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
if s.clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
if status == internal.StatusNeedsLogin {
|
||||
s.actCancel()
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
|
||||
@@ -642,16 +621,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle {
|
||||
return nil, fmt.Errorf("up already in progress: current status %s", status)
|
||||
}
|
||||
|
||||
// it should be nil here, but in case it isn't we cancel it.
|
||||
// it should be nil here, but .
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
|
||||
md, ok := metadata.FromIncomingContext(callerCtx)
|
||||
if ok {
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
@@ -694,31 +673,26 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
|
||||
return s.waitForUp(callerCtx)
|
||||
}
|
||||
|
||||
// todo: handle potential race conditions
|
||||
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
|
||||
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return nil, fmt.Errorf("client gave up to connect")
|
||||
case <-s.clientRunningChan:
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
case <-callerCtx.Done():
|
||||
log.Debug("context done, stopping the wait for engine to become ready")
|
||||
return nil, callerCtx.Err()
|
||||
case <-timeoutCtx.Done():
|
||||
log.Debug("up is timed out, stopping the wait for engine to become ready")
|
||||
return nil, timeoutCtx.Err()
|
||||
if !s.clientRunning {
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{}, 1)
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-s.clientRunningChan:
|
||||
s.isSessionActive.Store(true)
|
||||
return &proto.UpResponse{}, nil
|
||||
case <-callerCtx.Done():
|
||||
log.Debug("context done, stopping the wait for engine to become ready")
|
||||
return nil, callerCtx.Err()
|
||||
case <-timeoutCtx.Done():
|
||||
log.Debug("up is timed out, stopping the wait for engine to become ready")
|
||||
return nil, timeoutCtx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -992,47 +966,13 @@ func (s *Server) Status(
|
||||
ctx context.Context,
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
clientRunning := s.clientRunning
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-s.clientRunningChan:
|
||||
ticker.Stop()
|
||||
break loop
|
||||
case <-ticker.C:
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
|
||||
s.actCancel()
|
||||
}
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
status, err := internal.CtxGetState(s.rootCtx).Status()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
t.Setenv(maxRetryTimeVar, "5s")
|
||||
t.Setenv(retryMultiplierVar, "1")
|
||||
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
|
||||
if counter < 3 {
|
||||
t.Fatalf("expected counter > 2, got %d", counter)
|
||||
}
|
||||
@@ -134,12 +134,8 @@ func TestServer_Up(t *testing.T) {
|
||||
|
||||
profName := "default"
|
||||
|
||||
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
|
||||
require.NoError(t, err)
|
||||
|
||||
ic := profilemanager.ConfigInput{
|
||||
ConfigPath: filepath.Join(tempDir, profName+".json"),
|
||||
ManagementURL: u.String(),
|
||||
ConfigPath: filepath.Join(tempDir, profName+".json"),
|
||||
}
|
||||
|
||||
_, err = profilemanager.UpdateOrCreateConfig(ic)
|
||||
@@ -157,9 +153,16 @@ func TestServer_Up(t *testing.T) {
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
|
||||
require.NoError(t, err)
|
||||
s.config = &profilemanager.Config{
|
||||
ManagementURL: u,
|
||||
}
|
||||
|
||||
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -168,7 +171,6 @@ func TestServer_Up(t *testing.T) {
|
||||
Username: &currUser.Username,
|
||||
}
|
||||
_, err = s.Up(upCtx, upReq)
|
||||
log.Errorf("error from Up: %v", err)
|
||||
|
||||
assert.Contains(t, err.Error(), "context deadline exceeded")
|
||||
}
|
||||
|
||||
@@ -529,7 +529,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
var req proto.SetConfigRequest
|
||||
req.ProfileName = activeProf.Name
|
||||
req.Username = currUser.Username
|
||||
|
||||
|
||||
if iMngURL != "" {
|
||||
req.ManagementUrl = iMngURL
|
||||
}
|
||||
@@ -563,28 +563,27 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
// run down & up
|
||||
_, err = conn.Down(s.ctx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
|
||||
log.Errorf("down service: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
// run down & up
|
||||
_, err = conn.Down(s.ctx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
},
|
||||
OnCancel: func() {
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
type GRPCClient struct {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM ubuntu:24.04
|
||||
FROM ubuntu:24.10
|
||||
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
|
||||
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
|
||||
CMD ["--log-file", "console"]
|
||||
|
||||
@@ -302,11 +302,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
@@ -1655,24 +1651,3 @@ func peerSupportsPortRanges(peerVer string) bool {
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
@@ -2,17 +2,14 @@ package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -838,109 +835,3 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
}
|
||||
|
||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peer *nbpeer.Peer
|
||||
customZone nbdns.CustomZone
|
||||
peersToConnect []*nbpeer.Peer
|
||||
expectedRecords []nbdns.SimpleRecord
|
||||
}{
|
||||
{
|
||||
name: "empty peers to connect",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple peers multiple records match",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for i := 1; i <= 100; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
|
||||
})
|
||||
}
|
||||
return records
|
||||
}(),
|
||||
},
|
||||
peersToConnect: func() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
peers = append(peers, &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer%d", i),
|
||||
IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)),
|
||||
})
|
||||
}
|
||||
return peers
|
||||
}(),
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
|
||||
})
|
||||
}
|
||||
return records
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "peers with multiple DNS labels",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
|
||||
{Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{
|
||||
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
|
||||
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
|
||||
},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
|
||||
assert.Equal(t, len(tt.expectedRecords), len(result))
|
||||
assert.ElementsMatch(t, tt.expectedRecords, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
@@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
||||
|
||||
operation := func() error {
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
|
||||
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
quictls "github.com/netbirdio/netbird/shared/relay/tls"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||
@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
||||
|
||||
operation := func() error {
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
|
||||
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ErrSharedSockStopped indicates that shared socket has been stopped
|
||||
@@ -93,7 +93,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
|
||||
}
|
||||
|
||||
if err = nbnet.SetSocketMark(rawSock.conn4); err != nil {
|
||||
return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err)
|
||||
return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err)
|
||||
}
|
||||
|
||||
var sockErr error
|
||||
@@ -102,7 +102,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error
|
||||
log.Errorf("Failed to create ipv6 raw socket: %v", err)
|
||||
} else {
|
||||
if err = nbnet.SetSocketMark(rawSock.conn6); err != nil {
|
||||
return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err)
|
||||
return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,9 +20,8 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
@@ -58,7 +57,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
@@ -72,7 +71,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
|
||||
}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
31
util/net/conn.go
Normal file
31
util/net/conn.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
|
||||
dialerCloseHooksMutex.RLock()
|
||||
defer dialerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range dialerCloseHooks {
|
||||
if err := hook(c.ID, &c.Conn); err != nil {
|
||||
log.Errorf("Error executing dialer close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
58
util/net/dial.go
Normal file
58
util/net/dial.go
Normal file
@@ -0,0 +1,58 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
||||
@@ -16,5 +16,6 @@ func NewDialer() *Dialer {
|
||||
Dialer: &net.Dialer{},
|
||||
}
|
||||
dialer.init()
|
||||
|
||||
return dialer
|
||||
}
|
||||
107
util/net/dialer_dial.go
Normal file
107
util/net/dialer_dial.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
|
||||
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
|
||||
|
||||
var (
|
||||
dialerDialHooksMutex sync.RWMutex
|
||||
dialerDialHooks []DialerDialHookFunc
|
||||
dialerCloseHooksMutex sync.RWMutex
|
||||
dialerCloseHooks []DialerCloseHookFunc
|
||||
)
|
||||
|
||||
// AddDialerHook allows adding a new hook to be executed before dialing.
|
||||
func AddDialerHook(hook DialerDialHookFunc) {
|
||||
dialerDialHooksMutex.Lock()
|
||||
defer dialerDialHooksMutex.Unlock()
|
||||
dialerDialHooks = append(dialerDialHooks, hook)
|
||||
}
|
||||
|
||||
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
|
||||
func AddDialerCloseHook(hook DialerCloseHookFunc) {
|
||||
dialerCloseHooksMutex.Lock()
|
||||
defer dialerCloseHooksMutex.Unlock()
|
||||
dialerCloseHooks = append(dialerCloseHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveDialerHooks removes all dialer hooks.
|
||||
func RemoveDialerHooks() {
|
||||
dialerDialHooksMutex.Lock()
|
||||
defer dialerDialHooksMutex.Unlock()
|
||||
dialerDialHooks = nil
|
||||
|
||||
dialerCloseHooksMutex.Lock()
|
||||
defer dialerCloseHooksMutex.Unlock()
|
||||
dialerCloseHooks = nil
|
||||
}
|
||||
|
||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Dialing %s %s", network, address)
|
||||
|
||||
if CustomRoutingDisabled() {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
var resolver *net.Resolver
|
||||
if d.Resolver != nil {
|
||||
resolver = d.Resolver
|
||||
}
|
||||
|
||||
connID := GenerateConnID()
|
||||
if dialerDialHooks != nil {
|
||||
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
|
||||
log.Errorf("Failed to call dialer hooks: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
return &Conn{Conn: conn, ID: connID}, nil
|
||||
}
|
||||
|
||||
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("split host and port: %w", err)
|
||||
}
|
||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||
}
|
||||
|
||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||
|
||||
var result *multierror.Error
|
||||
|
||||
dialerDialHooksMutex.RLock()
|
||||
defer dialerDialHooksMutex.RUnlock()
|
||||
for _, hook := range dialerDialHooks {
|
||||
if err := hook(ctx, connID, ips); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
7
util/net/dialer_init_nonlinux.go
Normal file
7
util/net/dialer_init_nonlinux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !linux
|
||||
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
const (
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
||||
)
|
||||
|
||||
// CustomRoutingDisabled returns true if custom routing is disabled.
|
||||
12
util/net/env_generic.go
Normal file
12
util/net/env_generic.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package net
|
||||
|
||||
func Init() {
|
||||
// nothing to do on non-linux
|
||||
}
|
||||
|
||||
func AdvancedRouting() bool {
|
||||
// non-linux currently doesn't support advanced routing
|
||||
return false
|
||||
}
|
||||
@@ -17,7 +17,8 @@ import (
|
||||
|
||||
const (
|
||||
// these have the same effect, skip socket env supported for backward compatibility
|
||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
||||
)
|
||||
|
||||
var advancedRoutingSupported bool
|
||||
@@ -26,7 +27,6 @@ func Init() {
|
||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||
func AdvancedRouting() bool {
|
||||
return advancedRoutingSupported
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool {
|
||||
}
|
||||
|
||||
func CheckFwmarkSupport() bool {
|
||||
// temporarily enable advanced routing to check if fwmarks are supported
|
||||
// temporarily enable advanced routing to check fwmarks are supported
|
||||
old := advancedRoutingSupported
|
||||
advancedRoutingSupported = true
|
||||
defer func() {
|
||||
@@ -129,13 +129,3 @@ func CheckRuleOperationsSupport() bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on Linux
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on Linux - not needed for fwmark-based routing
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on Linux
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
37
util/net/listen.go
Normal file
37
util/net/listen.go
Normal file
@@ -0,0 +1,37 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
}
|
||||
|
||||
packetConn := conn.(*PacketConn)
|
||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := packetConn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||
}
|
||||
|
||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
||||
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
||||
type ListenerConfig struct {
|
||||
net.ListenConfig
|
||||
*net.ListenConfig
|
||||
}
|
||||
|
||||
// NewListener creates a new ListenerConfig instance.
|
||||
func NewListener() *ListenerConfig {
|
||||
listener := &ListenerConfig{}
|
||||
listener := &ListenerConfig{
|
||||
ListenConfig: &net.ListenConfig{},
|
||||
}
|
||||
listener.init()
|
||||
|
||||
return listener
|
||||
7
util/net/listener_init_nonlinux.go
Normal file
7
util/net/listener_init_nonlinux.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !linux
|
||||
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
||||
205
util/net/listener_listen.go
Normal file
205
util/net/listener_listen.go
Normal file
@@ -0,0 +1,205 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
|
||||
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
|
||||
|
||||
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
||||
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
||||
|
||||
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
|
||||
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
|
||||
var (
|
||||
listenerWriteHooksMutex sync.RWMutex
|
||||
listenerWriteHooks []ListenerWriteHookFunc
|
||||
listenerCloseHooksMutex sync.RWMutex
|
||||
listenerCloseHooks []ListenerCloseHookFunc
|
||||
listenerAddressRemoveHooksMutex sync.RWMutex
|
||||
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
|
||||
)
|
||||
|
||||
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
||||
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
|
||||
listenerWriteHooksMutex.Lock()
|
||||
defer listenerWriteHooksMutex.Unlock()
|
||||
listenerWriteHooks = append(listenerWriteHooks, hook)
|
||||
}
|
||||
|
||||
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
|
||||
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
|
||||
listenerCloseHooksMutex.Lock()
|
||||
defer listenerCloseHooksMutex.Unlock()
|
||||
listenerCloseHooks = append(listenerCloseHooks, hook)
|
||||
}
|
||||
|
||||
// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
|
||||
listenerAddressRemoveHooksMutex.Lock()
|
||||
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveListenerHooks removes all listener hooks.
|
||||
func RemoveListenerHooks() {
|
||||
listenerWriteHooksMutex.Lock()
|
||||
defer listenerWriteHooksMutex.Unlock()
|
||||
listenerWriteHooks = nil
|
||||
|
||||
listenerCloseHooksMutex.Lock()
|
||||
defer listenerCloseHooksMutex.Unlock()
|
||||
listenerCloseHooks = nil
|
||||
|
||||
listenerAddressRemoveHooksMutex.Lock()
|
||||
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||
listenerAddressRemoveHooks = nil
|
||||
}
|
||||
|
||||
// ListenPacket listens on the network address and returns a PacketConn
|
||||
// which includes support for write hooks.
|
||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
}
|
||||
|
||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen packet: %w", err)
|
||||
}
|
||||
connID := GenerateConnID()
|
||||
|
||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
ID ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
||||
return c.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *PacketConn) Close() error {
|
||||
c.seenAddrs = &sync.Map{}
|
||||
return closeConn(c.ID, c.PacketConn)
|
||||
}
|
||||
|
||||
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type UDPConn struct {
|
||||
*net.UDPConn
|
||||
ID ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
||||
return c.UDPConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *UDPConn) Close() error {
|
||||
c.seenAddrs = &sync.Map{}
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
}
|
||||
|
||||
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||
func (c *PacketConn) RemoveAddress(addr string) {
|
||||
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||
return
|
||||
}
|
||||
|
||||
ipStr, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
|
||||
|
||||
listenerAddressRemoveHooksMutex.RLock()
|
||||
defer listenerAddressRemoveHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerAddressRemoveHooks {
|
||||
if err := hook(c.ID, prefix); err != nil {
|
||||
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality
|
||||
func WrapPacketConn(conn net.PacketConn) *PacketConn {
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
ID: GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
||||
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
||||
ipStr, _, splitErr := net.SplitHostPort(addr.String())
|
||||
if splitErr != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", splitErr)
|
||||
return
|
||||
}
|
||||
|
||||
ip, err := net.ResolveIPAddr("ip", ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error resolving IP address: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
|
||||
|
||||
func() {
|
||||
listenerWriteHooksMutex.RLock()
|
||||
defer listenerWriteHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerWriteHooks {
|
||||
if err := hook(id, ip, b); err != nil {
|
||||
log.Errorf("Error executing listener write hook: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func closeConn(id ConnectionID, conn net.PacketConn) error {
|
||||
err := conn.Close()
|
||||
|
||||
listenerCloseHooksMutex.RLock()
|
||||
defer listenerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerCloseHooks {
|
||||
if err := hook(id, conn); err != nil {
|
||||
log.Errorf("Error executing listener close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -42,6 +44,18 @@ func IsDataPlaneMark(fwmark uint32) bool {
|
||||
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
|
||||
}
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||
type ConnectionID string
|
||||
|
||||
type AddHookFunc func(connID ConnectionID, IP net.IP) error
|
||||
type RemoveHookFunc func(connID ConnectionID) error
|
||||
|
||||
// GenerateConnID generates a unique identifier for each connection.
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
||||
var endIP net.IP
|
||||
addr := network.Addr().AsSlice()
|
||||
Reference in New Issue
Block a user