mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-11 11:46:28 -04:00
Compare commits
22 Commits
refactor/n
...
feature/ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90c45a6e24 | ||
|
|
cc97cffff1 | ||
|
|
c28275611b | ||
|
|
56f169eede | ||
|
|
07cf9d5895 | ||
|
|
7df49e249d | ||
|
|
dbfc8a52c9 | ||
|
|
98ddac07bf | ||
|
|
48475ddc05 | ||
|
|
6aa4ba7af4 | ||
|
|
2e16c9914a | ||
|
|
5c29d395b2 | ||
|
|
229e0038ee | ||
|
|
75327d9519 | ||
|
|
c92e6c1b5f | ||
|
|
641eb5140b | ||
|
|
45c25dca84 | ||
|
|
679c58ce47 | ||
|
|
719283c792 | ||
|
|
a2313a5ba4 | ||
|
|
8c108ccad3 | ||
|
|
86eff0d750 |
@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
|
||||
if err := configureSystemdNetworkd(); err != nil {
|
||||
log.Warnf("failed to configure systemd-networkd: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
|
||||
return fmt.Errorf("uninstall service: %w", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
if err := cleanupSystemdNetworkd(); err != nil {
|
||||
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("NetBird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
@@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) {
|
||||
|
||||
return status == service.StatusRunning, nil
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
# routes and policy rules managed by NetBird.
|
||||
|
||||
[Network]
|
||||
ManageForeignRoutes=no
|
||||
ManageForeignRoutingPolicyRules=no
|
||||
`
|
||||
)
|
||||
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
|
||||
func cleanupSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := os.Remove(networkdConfFile); err != nil {
|
||||
return fmt.Errorf("remove networkd configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
@@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
@@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
}
|
||||
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Info("creating an iptables firewall manager")
|
||||
return nbiptables.Create(iface)
|
||||
return nbiptables.Create(iface, mtu)
|
||||
case NFTABLES:
|
||||
log.Info("creating an nftables firewall manager")
|
||||
return nbnftables.Create(iface)
|
||||
return nbnftables.Create(iface, mtu)
|
||||
default:
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -36,7 +36,7 @@ type iFaceMapper interface {
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init iptables: %w", err)
|
||||
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(iptablesClient, wgIface)
|
||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -30,17 +30,20 @@ const (
|
||||
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
@@ -48,6 +51,9 @@ const (
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
@@ -77,16 +83,18 @@ type router struct {
|
||||
ipsetCounter *ipsetCounter
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
r := &router{
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
"-j", chainRTMSSCLAMP,
|
||||
}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
|
||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||
}
|
||||
r.rules[jumpMSSClamp] = jumpRule
|
||||
|
||||
ruleOut := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mss),
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
|
||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||
}
|
||||
r.rules["mss-clamp-out"] = ruleOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
var table, chain string
|
||||
switch ruleKey {
|
||||
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
|
||||
case jumpNatPre:
|
||||
table = tableNat
|
||||
chain = chainPREROUTING
|
||||
case jumpMSSClamp:
|
||||
table = tableMangle
|
||||
chain = chainFORWARD
|
||||
default:
|
||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
// Now 5 rules:
|
||||
// 1. established rule forward in
|
||||
// 2. estbalished rule forward out
|
||||
// 3. jump rule to POST nat chain
|
||||
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 7. static return masquerade rule
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
// 10. jump rule to MSS clamping chain
|
||||
// 11. MSS clamping rule for outbound traffic
|
||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock)
|
||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -11,6 +12,7 @@ type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
ipt, err := Create(s.InterfaceState)
|
||||
mtu := s.InterfaceState.MTU
|
||||
if mtu == 0 {
|
||||
mtu = iface.DefaultMTU
|
||||
}
|
||||
ipt, err := Create(s.InterfaceState, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iptables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -100,6 +100,9 @@ type Manager interface {
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
//
|
||||
// Note: Callers should call Flush() after adding rules to ensure
|
||||
// they are applied to the kernel and rule handles are refreshed.
|
||||
AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
|
||||
@@ -29,8 +29,6 @@ const (
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
// mask
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
nftRule: r.nftRule,
|
||||
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||
}
|
||||
|
||||
ruleStruct := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftRule: nftRule,
|
||||
// best effort mangle rule
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
|
||||
},
|
||||
)
|
||||
|
||||
return m.rConn.AddRule(&nftables.Rule{
|
||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":" + string(proto) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
@@ -19,13 +19,22 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
|
||||
tableNameNetbird = "netbird"
|
||||
// envTableName is the environment variable to override the table name
|
||||
envTableName = "NB_NFTABLES_TABLE"
|
||||
|
||||
tableNameFilter = "filter"
|
||||
chainNameInput = "INPUT"
|
||||
)
|
||||
|
||||
func getTableName() string {
|
||||
if name := os.Getenv(envTableName); name != "" {
|
||||
return name
|
||||
}
|
||||
return tableNameNetbird
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
@@ -44,16 +53,16 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface)
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
err := m.aclManager.createDefaultAllowRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chain == nil {
|
||||
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||
}
|
||||
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
m.applyAllowNetbirdRules(chain)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.resetNetbirdInputRules(); err != nil {
|
||||
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetNetbirdInputRules() error {
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
m.deleteNetbirdInputRules(chains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
m.deleteMatchingRules(rules)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
@@ -398,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
}
|
||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||
continue
|
||||
}
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
// This test verifies rule insertion order in nftables peer ACLs
|
||||
// We add accept rule first, then deny rule to test ordering behavior
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -32,12 +33,17 @@ const (
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
@@ -63,9 +69,10 @@ type router struct {
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
r := &router{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
||||
func (r *router) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
// Reset cleans existing nftables filter table rules from the system
|
||||
func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameMangleForward,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 13,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 1,
|
||||
Mask: []byte{0x02},
|
||||
Xor: []byte{0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Exthdr{
|
||||
DestRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGt,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Exthdr{
|
||||
SourceRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameMangleForward],
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward rules", fw)
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptForwardRulesNftables()
|
||||
return r.acceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptForwardRulesIptables(ipt)
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables rule: %v", rule)
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesNftables() error {
|
||||
func (r *router) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
|
||||
},
|
||||
}
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
|
||||
inputRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
}
|
||||
r.conn.InsertRule(inputRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRules() error {
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptForwardRulesNftables()
|
||||
return r.removeAcceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptForwardRulesIptables(ipt)
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -10,6 +11,7 @@ type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
nft, err := Create(s.InterfaceState)
|
||||
mtu := s.InterfaceState.MTU
|
||||
if mtu == 0 {
|
||||
mtu = iface.DefaultMTU
|
||||
}
|
||||
nft, err := Create(s.InterfaceState, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -27,7 +28,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
const (
|
||||
layerTypeAll = 0
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
@@ -36,6 +42,9 @@ const (
|
||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||
|
||||
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
||||
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
||||
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
@@ -50,6 +59,12 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
@@ -113,6 +128,13 @@ type Manager struct {
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -131,16 +153,16 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func parseCreateEnv() (bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding bool
|
||||
func parseCreateEnv() (bool, bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||
disableConntrack, err = strconv.ParseBool(val)
|
||||
@@ -168,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
|
||||
disableMSSClamping, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@@ -203,13 +231,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
mtu: mtu,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if !disableMSSClamping {
|
||||
m.mssClampEnabled = true
|
||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||
}
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
@@ -217,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding {
|
||||
if err := m.initForwarder(); err != nil {
|
||||
log.Errorf("failed to initialize forwarder: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
@@ -327,7 +357,7 @@ func (m *Manager) initForwarder() error {
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
|
||||
if err != nil {
|
||||
m.routingEnabled.Store(false)
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
@@ -633,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
m.clampTCPMSS(packetData, d)
|
||||
}
|
||||
}
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||
@@ -681,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
|
||||
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
|
||||
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
if !d.tcp.SYN {
|
||||
return false
|
||||
}
|
||||
if len(d.tcp.Options) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
mssOptionIndex := -1
|
||||
var currentMSS uint16
|
||||
for i, opt := range d.tcp.Options {
|
||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||
if currentMSS > m.mssClampValue {
|
||||
mssOptionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mssOptionIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||
tcpHeaderStart := ipHeaderSize
|
||||
tcpOptionsStart := tcpHeaderStart + 20
|
||||
|
||||
optOffset := tcpOptionsStart
|
||||
for j := 0; j < mssOptionIndex; j++ {
|
||||
switch d.tcp.Options[j].OptionType {
|
||||
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
|
||||
optOffset++
|
||||
default:
|
||||
optOffset += 2 + len(d.tcp.Options[j].OptionData)
|
||||
}
|
||||
}
|
||||
|
||||
mssValueOffset := optOffset + 2
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||
|
||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
|
||||
tcpLayer := packetData[tcpHeaderStart:]
|
||||
tcpLength := len(packetData) - tcpHeaderStart
|
||||
|
||||
tcpLayer[16] = 0
|
||||
tcpLayer[17] = 0
|
||||
|
||||
var pseudoSum uint32
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
var sum uint32 = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
if tcpLength%2 == 1 {
|
||||
sum += uint32(tcpLayer[tcpLength-1]) << 8
|
||||
}
|
||||
|
||||
for sum > 0xFFFF {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
|
||||
checksum := ^uint16(sum)
|
||||
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
@@ -838,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
if m.shouldForward(d, dstIP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
@@ -1274,3 +1402,86 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
m.netstackServices[key] = struct{}{}
|
||||
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
|
||||
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
|
||||
}
|
||||
|
||||
// UnregisterNetstackService removes a service from the netstack registry
|
||||
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
delete(m.netstackServices, key)
|
||||
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
|
||||
}
|
||||
|
||||
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
|
||||
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
|
||||
switch protocol {
|
||||
case nftypes.TCP:
|
||||
return layers.LayerTypeTCP
|
||||
case nftypes.UDP:
|
||||
return layers.LayerTypeUDP
|
||||
case nftypes.ICMP:
|
||||
return layers.LayerTypeICMPv4
|
||||
default:
|
||||
return gopacket.LayerType(0) // Invalid/unknown
|
||||
}
|
||||
}
|
||||
|
||||
// shouldForward determines if a packet should be forwarded to the forwarder.
|
||||
// The forwarder handles routing packets to the native OS network stack.
|
||||
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
|
||||
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
// not enabled, never forward
|
||||
if !m.localForwarding {
|
||||
return false
|
||||
}
|
||||
|
||||
// netstack always needs to forward because it's lacking a native interface
|
||||
// exception for registered netstack services, those should go to netstack listeners
|
||||
if m.netstack {
|
||||
return !m.hasMatchingNetstackService(d)
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
return true
|
||||
}
|
||||
|
||||
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
|
||||
return false
|
||||
}
|
||||
|
||||
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
|
||||
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
|
||||
if len(d.decoded) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
var dstPort uint16
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
key := serviceKey{protocol: d.decoded[1], port: dstPort}
|
||||
m.netstackServiceMutex.RLock()
|
||||
_, exists := m.netstackServices[key]
|
||||
m.netstackServiceMutex.RUnlock()
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
// Set TCP flags
|
||||
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
||||
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
||||
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
||||
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
||||
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func BenchmarkRouteACLs(b *testing.B) {
|
||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||
|
||||
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
|
||||
// This shows the overhead difference between the common case (non-SYN packets, fast path)
|
||||
// and the rare case (SYN packets that need clamping, expensive path).
|
||||
func BenchmarkMSSClamping(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
description string
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
frequency string
|
||||
}{
|
||||
{
|
||||
name: "syn_needs_clamp",
|
||||
description: "SYN packet needing MSS clamping",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
frequency: "~0.1% of traffic - EXPENSIVE",
|
||||
},
|
||||
{
|
||||
name: "syn_no_clamp_needed",
|
||||
description: "SYN packet with already-small MSS",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
|
||||
},
|
||||
frequency: "~0.05% of traffic",
|
||||
},
|
||||
{
|
||||
name: "tcp_ack",
|
||||
description: "Non-SYN TCP packet (ACK, data transfer)",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
frequency: "~60-70% of traffic - FAST PATH",
|
||||
},
|
||||
{
|
||||
name: "tcp_psh_ack",
|
||||
description: "TCP data packet (PSH+ACK)",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
},
|
||||
frequency: "~10-20% of traffic - FAST PATH",
|
||||
},
|
||||
{
|
||||
name: "udp",
|
||||
description: "UDP packet",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||
},
|
||||
frequency: "~20-30% of traffic - FAST PATH",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
|
||||
// for the common case (non-SYN TCP packets).
|
||||
func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
}{
|
||||
{
|
||||
name: "disabled_tcp_ack",
|
||||
enabled: false,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled_tcp_ack",
|
||||
enabled: true,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "disabled_syn_needs_clamp",
|
||||
enabled: false,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled_syn_needs_clamp",
|
||||
enabled: true,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = sc.enabled
|
||||
if sc.enabled {
|
||||
manager.mssClampValue = 1240
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
|
||||
func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
}{
|
||||
{
|
||||
name: "tcp_ack_fast_path",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "syn_needs_clamp",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "udp_fast_path",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ip := &layers.IPv4{
|
||||
Version: 4,
|
||||
IHL: 5,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
Seq: 1000,
|
||||
Window: 65535,
|
||||
}
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false, flowLogger)
|
||||
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close()
|
||||
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSClamping(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := net.ParseIP("100.10.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
|
||||
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
|
||||
highMSS := uint16(1460)
|
||||
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
})
|
||||
|
||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||
lowMSS := uint16(1200)
|
||||
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
|
||||
})
|
||||
|
||||
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
|
||||
highMSS := uint16(1460)
|
||||
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
})
|
||||
|
||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
|
||||
})
|
||||
}
|
||||
|
||||
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
Window: 65535,
|
||||
Options: []layers.TCPOption{
|
||||
{
|
||||
OptionType: layers.TCPOptionKindMSS,
|
||||
OptionLength: 4,
|
||||
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
ACK: true,
|
||||
Window: 65535,
|
||||
Options: []layers.TCPOption{
|
||||
{
|
||||
OptionType: layers.TCPOptionKindMSS,
|
||||
OptionLength: 4,
|
||||
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
Window: 65535,
|
||||
}
|
||||
|
||||
if flags&uint16(conntrack.TCPSyn) != 0 {
|
||||
tcpLayer.SYN = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPAck) != 0 {
|
||||
tcpLayer.ACK = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPFin) != 0 {
|
||||
tcpLayer.FIN = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPRst) != 0 {
|
||||
tcpLayer.RST = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPPush) != 0 {
|
||||
tcpLayer.PSH = true
|
||||
}
|
||||
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
|
||||
@@ -49,7 +49,7 @@ type idleConn struct {
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) {
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
@@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
|
||||
@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
if d.providerConfig.LoginHint != "" {
|
||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||
if deviceCode.VerificationURI != "" {
|
||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||
}
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func appendLoginHint(uri, loginHint string) string {
|
||||
if uri == "" || loginHint == "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||
return uri
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("login_hint", loginHint)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
return parsedURL.String()
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
|
||||
@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
if p.providerConfig.LoginHint != "" {
|
||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -34,7 +35,6 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.engine = nil
|
||||
}
|
||||
c.engineMutex.Unlock()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
}
|
||||
c.engineMutex.Lock()
|
||||
defer c.engineMutex.Unlock()
|
||||
|
||||
if c.engine == nil {
|
||||
return nil
|
||||
}
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
@@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||
|
||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||
|
||||
DNS Configuration
|
||||
The debug bundle includes platform-specific DNS configuration files:
|
||||
|
||||
resolv.conf (Unix systems):
|
||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||
- Includes nameserver entries, search domains, and resolver options
|
||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||
|
||||
scutil_dns.txt (macOS only):
|
||||
- Contains detailed DNS configuration from scutil --dns
|
||||
- Shows DNS configuration for all network interfaces
|
||||
- Includes search domains, nameservers, and DNS resolver settings
|
||||
- All IP addresses and domain names are anonymized
|
||||
`
|
||||
|
||||
const (
|
||||
@@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
if err := g.addFirewallRules(); err != nil {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addReadme() error {
|
||||
|
||||
53
client/internal/debug/debug_darwin.go
Normal file
53
client/internal/debug/debug_darwin.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addScutilDNS(); err != nil {
|
||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addScutilDNS() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(output)) == 0 {
|
||||
return fmt.Errorf("no scutil DNS output")
|
||||
}
|
||||
|
||||
content := string(output)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,3 +5,7 @@ package debug
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
16
client/internal/debug/debug_nondarwin.go
Normal file
16
client/internal/debug/debug_nondarwin.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build unix && !darwin && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
client/internal/debug/debug_nonunix.go
Normal file
7
client/internal/debug/debug_nonunix.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !unix
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
29
client/internal/debug/debug_unix.go
Normal file
29
client/internal/debug/debug_unix.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build unix && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
func (g *BundleGenerator) addResolvConf() error {
|
||||
data, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
shutdownWg sync.WaitGroup
|
||||
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
||||
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
||||
disableSys bool
|
||||
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.ctxCancel()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
defer s.shutdownWg.Done()
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -33,7 +34,7 @@ type firewaller interface {
|
||||
}
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
listenAddress netip.AddrPort
|
||||
ttl uint32
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -47,9 +48,11 @@ type DNSForwarder struct {
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
|
||||
wgIface wgIface
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
@@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
||||
var netstackNet *netstack.Net
|
||||
if f.wgIface != nil {
|
||||
netstackNet = f.wgIface.GetNet()
|
||||
}
|
||||
|
||||
addrDesc := f.listenAddress.String()
|
||||
if netstackNet != nil {
|
||||
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
|
||||
}
|
||||
log.Infof("starting DNS forwarder on address=%s", addrDesc)
|
||||
|
||||
udpLn, err := f.createUDPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create UDP listener: %w", err)
|
||||
}
|
||||
|
||||
tcpLn, err := f.createTCPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// UDP server
|
||||
mux := dns.NewServeMux()
|
||||
f.mux = mux
|
||||
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||
f.dnsServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
PacketConn: udpLn,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// TCP server
|
||||
tcpMux := dns.NewServeMux()
|
||||
f.tcpMux = tcpMux
|
||||
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||
f.tcpServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "tcp",
|
||||
Handler: tcpMux,
|
||||
Listener: tcpLn,
|
||||
Handler: tcpMux,
|
||||
}
|
||||
|
||||
f.UpdateDomains(entries)
|
||||
@@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
||||
errCh <- f.dnsServer.ListenAndServe()
|
||||
log.Infof("DNS UDP listener running on %s", addrDesc)
|
||||
errCh <- f.dnsServer.ActivateAndServe()
|
||||
}()
|
||||
go func() {
|
||||
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
||||
errCh <- f.tcpServer.ListenAndServe()
|
||||
log.Infof("DNS TCP listener running on %s", addrDesc)
|
||||
errCh <- f.tcpServer.ActivateAndServe()
|
||||
}()
|
||||
|
||||
// return the first error we get (e.g. bind failure or shutdown)
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenUDPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenTCPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||
}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString(tt.configuredDomain)
|
||||
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
// Set up forwarder
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Create entries and track sets
|
||||
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Configure a single domain
|
||||
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
d, err := domain.FromString(tt.configured)
|
||||
require.NoError(t, err)
|
||||
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, _ := domain.FromString("example.com")
|
||||
@@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
@@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Set up complex overlapping patterns
|
||||
@@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
|
||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
// Test handling of malformed query with no questions
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
@@ -10,9 +10,12 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -24,6 +27,12 @@ const (
|
||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||
)
|
||||
|
||||
// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
|
||||
type wgIface interface {
|
||||
GetNet() *netstack.Net
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
type ForwarderEntry struct {
|
||||
Domain domain.Domain
|
||||
@@ -34,7 +43,7 @@ type ForwarderEntry struct {
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
localAddr netip.Addr
|
||||
wgIface wgIface
|
||||
serverPort uint16
|
||||
|
||||
fwRules []firewall.Rule
|
||||
@@ -42,7 +51,7 @@ type Manager struct {
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
|
||||
serverPort := nbdns.ForwarderServerPort
|
||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||
@@ -56,7 +65,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
localAddr: localAddr,
|
||||
wgIface: wgIface,
|
||||
serverPort: serverPort,
|
||||
}
|
||||
}
|
||||
@@ -71,21 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
localAddr := m.wgIface.Address().IP
|
||||
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
|
||||
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
|
||||
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -111,16 +124,19 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
localAddr := m.wgIface.Address().IP
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||
}
|
||||
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
m.unregisterNetstackServices()
|
||||
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
@@ -145,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||
}
|
||||
m.fwRules = dnsRules
|
||||
|
||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.firewall.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
m.fwRules = dnsRules
|
||||
m.tcpRules = tcpRules
|
||||
|
||||
m.registerNetstackServices()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) registerNetstackServices() {
|
||||
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := m.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
|
||||
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
|
||||
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) unregisterNetstackServices() {
|
||||
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := m.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
|
||||
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
|
||||
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range m.fwRules {
|
||||
|
||||
@@ -148,6 +148,8 @@ type Engine struct {
|
||||
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
// sshMux protects sshServer field access
|
||||
sshMux sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
@@ -200,8 +202,10 @@ type Engine struct {
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
}
|
||||
@@ -298,17 +302,12 @@ func (e *Engine) Stop() error {
|
||||
e.ingressGatewayMgr = nil
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.dnsForwardMgr != nil {
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
@@ -325,10 +324,6 @@ func (e *Engine) Stop() error {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
|
||||
// stop flow manager after wg interface is gone
|
||||
@@ -336,8 +331,6 @@ func (e *Engine) Stop() error {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -348,12 +341,52 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
func (e *Engine) calculateShutdownTimeout() time.Duration {
|
||||
peerCount := len(e.peerStore.PeersPubKey())
|
||||
|
||||
baseTimeout := 10 * time.Second
|
||||
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
|
||||
timeout := baseTimeout + perPeerTimeout
|
||||
|
||||
maxTimeout := 30 * time.Second
|
||||
if timeout > maxTimeout {
|
||||
timeout = maxTimeout
|
||||
}
|
||||
|
||||
return timeout
|
||||
}
|
||||
|
||||
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
|
||||
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
@@ -483,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
e.shutdownWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
defer e.shutdownWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
e.triggerClientRestart()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
@@ -506,7 +539,7 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
@@ -674,9 +707,11 @@ func (e *Engine) removeAllPeers() error {
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
@@ -878,6 +913,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
@@ -885,34 +921,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
if err != nil {
|
||||
e.sshMux.Unlock()
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
e.sshMux.Unlock()
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshMux.Lock()
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
e.sshMux.Unlock()
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
} else {
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -949,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
@@ -1125,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
// update SSHServer by adding remote peer SSH keys
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
for _, config := range networkMap.GetRemotePeers() {
|
||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||
@@ -1135,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1377,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1494,12 +1544,14 @@ func (e *Engine) close() {
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
@@ -1730,8 +1782,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
// triggerClientRestart triggers a full client restart by cancelling the client context.
|
||||
// Note: This does NOT just restart the engine - it cancels the entire client context,
|
||||
// which causes the connect client's retry loop to create a completely new engine.
|
||||
func (e *Engine) triggerClientRestart() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -1753,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Infof("network monitor stopped")
|
||||
@@ -1763,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
log.Infof("Network monitor: detected network change, triggering client restart")
|
||||
e.triggerClientRestart()
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1855,38 +1911,46 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.stopDNSForwarder()
|
||||
return
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
if e.dnsForwardMgr == nil {
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
e.startDNSForwarder(fwdEntries)
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
e.stopDNSForwarder()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSForwarder() {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
intf := e.wgInterface
|
||||
|
||||
@@ -26,6 +26,9 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -1556,7 +1559,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -1584,13 +1586,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
// Manager handles netflow tracking and logging
|
||||
type Manager struct {
|
||||
mux sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
m.shutdownWg.Add(2)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
m.receiveACKs(ctx, flowClient)
|
||||
}()
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
m.startSender(ctx)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
// Close cleans up all resources
|
||||
func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
m.mux.Unlock()
|
||||
|
||||
m.shutdownWg.Wait()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package networkmonitor
|
||||
|
||||
@@ -6,21 +6,19 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
fd, err := prepareFd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
@@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
return routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
}
|
||||
|
||||
92
client/internal/networkmonitor/check_change_common.go
Normal file
92
client/internal/networkmonitor/check_change_common.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd || darwin
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func prepareFd() (int, error) {
|
||||
return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
}
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
149
client/internal/networkmonitor/check_change_darwin.go
Normal file
149
client/internal/networkmonitor/check_change_darwin.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// todo: refactor to not use static functions
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := prepareFd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil {
|
||||
if !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
routeChanged := make(chan struct{})
|
||||
go func() {
|
||||
_ = routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
close(routeChanged)
|
||||
}()
|
||||
|
||||
wakeUp := make(chan struct{})
|
||||
go func() {
|
||||
wakeUpListen(ctx)
|
||||
close(wakeUp)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-routeChanged:
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Infof("route change detected")
|
||||
return nil
|
||||
case <-wakeUp:
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Infof("wakeup detected")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func wakeUpListen(ctx context.Context) {
|
||||
log.Infof("start to watch for system wakeups")
|
||||
var (
|
||||
initialHash uint32
|
||||
err error
|
||||
)
|
||||
|
||||
// Keep retrying until initial sysctl succeeds or context is canceled
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
default:
|
||||
initialHash, err = readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to detect initial sleep time: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
continue
|
||||
}
|
||||
}
|
||||
log.Debugf("initial wakeup hash: %d", initialHash)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("context canceled, stopping wakeUpListen")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
newHash, err := readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to read sleep time hash: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if newHash == initialHash {
|
||||
log.Tracef("no wakeup detected")
|
||||
continue
|
||||
}
|
||||
|
||||
upOut, err := exec.Command("uptime").Output()
|
||||
if err != nil {
|
||||
log.Errorf("failed to run uptime command: %v", err)
|
||||
upOut = []byte("unknown")
|
||||
}
|
||||
log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readSleepTimeHash() (uint32, error) {
|
||||
cmd := exec.Command("sysctl", "kern.sleeptime")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to run sysctl: %w", err)
|
||||
}
|
||||
|
||||
h, err := hash(out)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to compute hash: %w", err)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func hash(data []byte) (uint32, error) {
|
||||
hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher
|
||||
if _, err := hasher.Write(data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return hasher.Sum32(), nil
|
||||
}
|
||||
@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
event := make(chan struct{}, 1)
|
||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||
|
||||
log.Infof("start watching for network changes")
|
||||
// debounce changes
|
||||
timer := time.NewTimer(0)
|
||||
timer.Stop()
|
||||
|
||||
@@ -19,11 +19,10 @@ type SRWatcher struct {
|
||||
signalClient chNotifier
|
||||
relayManager chNotifier
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
cancelIceMonitor context.CancelFunc
|
||||
}
|
||||
|
||||
|
||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
if isController(w.config) {
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -81,6 +81,7 @@ type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter *server.Router
|
||||
@@ -283,6 +284,7 @@ func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
m.stop()
|
||||
m.shutdownWg.Wait()
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.CleanUp()
|
||||
}
|
||||
@@ -485,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
}
|
||||
clientNetworkWatcher := client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.Start()
|
||||
m.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
clientNetworkWatcher.Start()
|
||||
}()
|
||||
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||
}
|
||||
|
||||
@@ -527,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
}
|
||||
clientNetworkWatcher = client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.Start()
|
||||
m.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
clientNetworkWatcher.Start()
|
||||
}()
|
||||
}
|
||||
update := client.RoutesUpdate{
|
||||
UpdateSerial: updateSerial,
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
log.Debugf("Route %s not selected (deselect all)", routeID)
|
||||
return false
|
||||
}
|
||||
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
isSelected := !deselected
|
||||
log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected)
|
||||
return isSelected
|
||||
}
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
@@ -279,8 +279,10 @@ type LoginRequest struct {
|
||||
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
|
||||
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
|
||||
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *LoginRequest) Reset() {
|
||||
@@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetHint() string {
|
||||
if x != nil && x.Hint != nil {
|
||||
return *x.Hint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
||||
@@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor
|
||||
const file_daemon_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
|
||||
"\fEmptyRequest\"\xc3\x0e\n" +
|
||||
"\fEmptyRequest\"\xe5\x0e\n" +
|
||||
"\fLoginRequest\x12\x1a\n" +
|
||||
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
|
||||
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
|
||||
@@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
|
||||
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
|
||||
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
|
||||
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" +
|
||||
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
|
||||
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
|
||||
"\x11_rosenpassEnabledB\x10\n" +
|
||||
"\x0e_interfaceNameB\x10\n" +
|
||||
"\x0e_wireguardPortB\x17\n" +
|
||||
@@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x0e_block_inboundB\x0e\n" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_usernameB\x06\n" +
|
||||
"\x04_mtu\"\xb5\x01\n" +
|
||||
"\x04_mtuB\a\n" +
|
||||
"\x05_hint\"\xb5\x01\n" +
|
||||
"\rLoginResponse\x12$\n" +
|
||||
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
|
||||
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
|
||||
|
||||
@@ -158,6 +158,9 @@ message LoginRequest {
|
||||
optional string username = 31;
|
||||
|
||||
optional int64 mtu = 32;
|
||||
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
optional string hint = 33;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
|
||||
@@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
if msg.SetupKey == "" {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, err
|
||||
|
||||
@@ -14,6 +14,9 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -610,11 +610,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
|
||||
loginReq := &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
}
|
||||
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginReq.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(ctx, loginReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login to management: %w", err)
|
||||
}
|
||||
|
||||
14
go.mod
14
go.mod
@@ -56,6 +56,7 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
@@ -76,7 +77,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.7
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/quic-go/quic-go v0.48.2
|
||||
github.com/quic-go/quic-go v0.49.1
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
@@ -98,15 +99,17 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.35.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||
go.uber.org/mock v0.5.0
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/mod v0.25.0
|
||||
golang.org/x/mod v0.26.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.16.0
|
||||
golang.org/x/term v0.33.0
|
||||
golang.org/x/time v0.12.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@@ -146,7 +149,7 @@ require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/containerd v1.7.27 // indirect
|
||||
github.com/containerd/containerd v1.7.29 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
@@ -183,7 +186,6 @@ require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
@@ -241,11 +243,9 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||
|
||||
24
go.sum
24
go.sum
@@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
|
||||
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
|
||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
||||
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
@@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
|
||||
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
@@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
@@ -818,8 +818,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
|
||||
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
|
||||
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
|
||||
31
management/internals/controllers/network_map/controller/cache/dns_config_cache.go
vendored
Normal file
31
management/internals/controllers/network_map/controller/cache/dns_config_cache.go
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
@@ -0,0 +1,784 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
repo Repository
|
||||
metrics *metrics
|
||||
// This should not be here, but we need to maintain it for the time being
|
||||
accountManagerMetrics *telemetry.AccountManagerMetrics
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
|
||||
requestBuffer account.RequestBuffer
|
||||
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
mu sync.Mutex
|
||||
next *time.Timer
|
||||
update atomic.Bool
|
||||
}
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller {
|
||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
}
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
repo: newRepository(store),
|
||||
metrics: nMetrics,
|
||||
accountManagerMetrics: metrics.AccountManagerMetrics(),
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
requestBuffer: requestBuffer,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
settingsManager: settingsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
|
||||
proxyController: proxyController,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
hasPeersConnected := false
|
||||
for _, peer := range account.Peers {
|
||||
if c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
hasPeersConnected = true
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if !hasPeersConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
||||
}
|
||||
|
||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||
}
|
||||
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
if !c.peersUpdateManager.HasChannel(peerId) {
|
||||
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerId)
|
||||
if peer == nil {
|
||||
return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||
}
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
|
||||
return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get extra settings: %v", err)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err := c.repo.GetAccountPeers(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (c *Controller) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerAddedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := c.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
c.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||
return c.expNewNetworkMap || ok
|
||||
}
|
||||
|
||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||
a := c.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
c.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache.UpdateAccountPointer(account)
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
return c.dnsDomain
|
||||
}
|
||||
if settings.DNSDomain == "" {
|
||||
return c.dnsDomain
|
||||
}
|
||||
|
||||
return settings.DNSDomain
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
}
|
||||
|
||||
func (c *Controller) StartWarmup(ctx context.Context) {
|
||||
var initialInterval int64
|
||||
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
|
||||
interval, err := strconv.Atoi(intervalStr)
|
||||
if err != nil {
|
||||
initialInterval = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
|
||||
} else {
|
||||
initialInterval = int64(interval) * 10
|
||||
go func() {
|
||||
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
|
||||
startupPeriod, err := strconv.Atoi(startupPeriodStr)
|
||||
if err != nil {
|
||||
startupPeriod = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
|
||||
}
|
||||
time.Sleep(time.Duration(startupPeriod) * time.Second)
|
||||
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
|
||||
}()
|
||||
}
|
||||
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
|
||||
|
||||
}
|
||||
|
||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
if len(peers) == 0 {
|
||||
return int64(network_map.OldForwarderPort)
|
||||
}
|
||||
|
||||
reqVer := semver.Canonical(requiredVersion)
|
||||
|
||||
// Check if all peers have the required version or newer
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
// If any peer doesn't have version info, return 0
|
||||
return int64(network_map.OldForwarderPort)
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
||||
return int64(network_map.OldForwarderPort)
|
||||
}
|
||||
}
|
||||
|
||||
// All peers have the required version or newer
|
||||
return int64(network_map.DnsForwarderPort)
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isInGroup {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
|
||||
if postureCheck == nil {
|
||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||
}
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := account.GetGroup(sourceGroup)
|
||||
if group == nil {
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
|
||||
c.UpdatePeerInNetworkMapCache(accountId, peer)
|
||||
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
||||
account, err := c.repo.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
groups := make(map[string][]string)
|
||||
for groupID, group := range account.Groups {
|
||||
groups[groupID] = group.Peers
|
||||
}
|
||||
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) IsConnected(peerID string) bool {
|
||||
return c.peersUpdateManager.HasChannel(peerID)
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestComputeForwarderPort(t *testing.T) {
|
||||
// Test with empty peers list
|
||||
peers := []*nbpeer.Peer{}
|
||||
result := computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have old versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.26.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have new versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.DnsForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have mixed versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have empty version
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "development",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result == int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have unknown version string
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "unknown",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||
const (
|
||||
peersCount = 1000
|
||||
updateAccountInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
||||
uapLastRun, dpLastRun atomic.Int64
|
||||
|
||||
totalNewRuns, totalOldRuns int
|
||||
)
|
||||
|
||||
uap := func(ctx context.Context, accountID string) {
|
||||
updatePeersDeleted.Store(deletedPeers.Load())
|
||||
updatePeersRuns.Add(1)
|
||||
uapLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Run("new approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := mu.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
uap(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
b.next = time.AfterFunc(updateAccountInterval, func() {
|
||||
uap(ctx, accountID)
|
||||
})
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalNewRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
|
||||
t.Run("old approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
||||
b := mu.(*sync.Mutex)
|
||||
|
||||
if !b.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(updateAccountInterval)
|
||||
b.Unlock()
|
||||
uap(ctx, accountID)
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalOldRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type metrics struct {
|
||||
*telemetry.UpdateChannelMetrics
|
||||
}
|
||||
|
||||
func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) {
|
||||
return &metrics{
|
||||
updateChannelMetrics,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
|
||||
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
var _ Repository = (*repository)(nil)
|
||||
|
||||
func newRepository(s store.Store) Repository {
|
||||
return &repository{
|
||||
store: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) {
|
||||
return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) {
|
||||
return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
||||
return r.store.GetAccountByPeerID(ctx, peerID)
|
||||
}
|
||||
39
management/internals/controllers/network_map/interface.go
Normal file
39
management/internals/controllers/network_map/interface.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package network_map
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
|
||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||
OldForwarderPort = nbdns.ForwarderClientPort
|
||||
DnsForwarderPortMinVersion = "v0.59.0"
|
||||
)
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
DeletePeer(ctx context.Context, accountId string, peerId string) error
|
||||
|
||||
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
|
||||
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
|
||||
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
|
||||
DisconnectPeers(ctx context.Context, peerIDs []string)
|
||||
IsConnected(peerID string) bool
|
||||
}
|
||||
225
management/internals/controllers/network_map/interface_mock.go
Normal file
225
management/internals/controllers/network_map/interface_mock.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
//
|
||||
|
||||
// Package network_map is a generated GoMock package.
|
||||
package network_map
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
posture "github.com/netbirdio/netbird/management/server/posture"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockControllerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockControllerMockRecorder is the mock recorder for MockController.
|
||||
type MockControllerMockRecorder struct {
|
||||
mock *MockController
|
||||
}
|
||||
|
||||
// NewMockController creates a new mock instance.
|
||||
func NewMockController(ctrl *gomock.Controller) *MockController {
|
||||
mock := &MockController{ctrl: ctrl}
|
||||
mock.recorder = &MockControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// DeletePeer mocks base method.
|
||||
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeletePeer indicates an expected call of DeletePeer.
|
||||
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
|
||||
}
|
||||
|
||||
// DisconnectPeers mocks base method.
|
||||
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
|
||||
}
|
||||
|
||||
// DisconnectPeers indicates an expected call of DisconnectPeers.
|
||||
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
|
||||
}
|
||||
|
||||
// GetDNSDomain mocks base method.
|
||||
func (m *MockController) GetDNSDomain(settings *types.Settings) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDNSDomain", settings)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDNSDomain indicates an expected call of GetDNSDomain.
|
||||
func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings)
|
||||
}
|
||||
|
||||
// GetNetworkMap mocks base method.
|
||||
func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID)
|
||||
ret0, _ := ret[0].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetNetworkMap indicates an expected call of GetNetworkMap.
|
||||
func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(int64)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// IsConnected mocks base method.
|
||||
func (m *MockController) IsConnected(peerID string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsConnected", peerID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsConnected indicates an expected call of IsConnected.
|
||||
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
|
||||
}
|
||||
|
||||
// OnPeerAdded mocks base method.
|
||||
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerAdded indicates an expected call of OnPeerAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerDeleted mocks base method.
|
||||
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerUpdated mocks base method.
|
||||
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
|
||||
}
|
||||
|
||||
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
func (m *MockController) StartWarmup(arg0 context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "StartWarmup", arg0)
|
||||
}
|
||||
|
||||
// StartWarmup indicates an expected call of StartWarmup.
|
||||
func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
|
||||
}
|
||||
|
||||
// UpdateAccountPeer mocks base method.
|
||||
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAccountPeer indicates an expected call of UpdateAccountPeer.
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId)
|
||||
}
|
||||
|
||||
// UpdateAccountPeers mocks base method.
|
||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
package network_map
|
||||
@@ -0,0 +1,13 @@
|
||||
package network_map
|
||||
|
||||
import "context"
|
||||
|
||||
type PeersUpdateManager interface {
|
||||
SendUpdate(ctx context.Context, peerID string, update *UpdateMessage)
|
||||
CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage
|
||||
CloseChannel(ctx context.Context, peerID string)
|
||||
CountStreams() int
|
||||
HasChannel(peerID string) bool
|
||||
CloseChannels(ctx context.Context, peerIDs []string)
|
||||
GetAllConnectedPeers() map[string]struct{}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package update_channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,38 +7,34 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *types.NetworkMap
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
peerChannels map[string]chan *network_map.UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil)
|
||||
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerChannels: make(map[string]chan *network_map.UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// SendUpdate sends update message to the peer's channel
|
||||
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
|
||||
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) {
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
@@ -66,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage {
|
||||
start := time.Now()
|
||||
|
||||
closed := false
|
||||
@@ -85,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
close(channel)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
channel := make(chan *network_map.UpdateMessage, channelBufferSize)
|
||||
p.peerChannels[peerID] = channel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
|
||||
@@ -176,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
func (p *PeersUpdateManager) CountStreams() int {
|
||||
p.channelsMux.RLock()
|
||||
defer p.channelsMux.RUnlock()
|
||||
return len(p.peerChannels)
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package server
|
||||
package update_channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) {
|
||||
func TestSendUpdate(t *testing.T) {
|
||||
peer := "test-sendupdate"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
update1 := &UpdateMessage{Update: &proto.SyncResponse{
|
||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 0,
|
||||
},
|
||||
@@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) {
|
||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||
}
|
||||
|
||||
update2 := &UpdateMessage{Update: &proto.SyncResponse{
|
||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 10,
|
||||
},
|
||||
@@ -0,0 +1,9 @@
|
||||
package network_map
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
@@ -91,9 +92,15 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) Router() *mux.Router {
|
||||
return Create(s, func() *mux.Router {
|
||||
return nbhttp.NewRouter(s.AuthManager(), s.Metrics(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.Router(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -145,7 +152,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
|
||||
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
@@ -14,9 +18,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
|
||||
return Create(s, func() *server.PeersUpdateManager {
|
||||
return server.NewPeersUpdateManager(s.Metrics())
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
return Create(s, func() *update_channel.PeersUpdateManager {
|
||||
return update_channel.NewPeersUpdateManager(s.Metrics())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
|
||||
return Create(s, func() *server.TimeBasedAuthSecretsManager {
|
||||
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
|
||||
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
|
||||
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
return Create(s, func() *nmapcontroller.Controller {
|
||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
return Create(s, func() *server.AccountRequestBuffer {
|
||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
|
||||
|
||||
func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
|
||||
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
352
management/internals/shared/grpc/conversion.go
Normal file
352
management/internals/shared/grpc/conversion.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
for _, stun := range config.Stuns {
|
||||
stuns = append(stuns, &proto.HostConfig{
|
||||
Uri: stun.URI,
|
||||
Protocol: ToResponseProto(stun.Proto),
|
||||
})
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
if config.TURNConfig != nil {
|
||||
for _, turn := range config.TURNConfig.Turns {
|
||||
var username string
|
||||
var password string
|
||||
if turnCredentials != nil {
|
||||
username = turnCredentials.Payload
|
||||
password = turnCredentials.Signature
|
||||
} else {
|
||||
username = turn.Username
|
||||
password = turn.Password
|
||||
}
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: turn.URI,
|
||||
Protocol: ToResponseProto(turn.Proto),
|
||||
},
|
||||
User: username,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var relayCfg *proto.RelayConfig
|
||||
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
|
||||
relayCfg = &proto.RelayConfig{
|
||||
Urls: config.Relay.Addresses,
|
||||
}
|
||||
|
||||
if relayToken != nil {
|
||||
relayCfg.TokenPayload = relayToken.Payload
|
||||
relayCfg.TokenSignature = relayToken.Signature
|
||||
}
|
||||
}
|
||||
|
||||
var signalCfg *proto.HostConfig
|
||||
if config.Signal != nil {
|
||||
signalCfg = &proto.HostConfig{
|
||||
Uri: config.Signal.URI,
|
||||
Protocol: ToResponseProto(config.Signal.Proto),
|
||||
}
|
||||
}
|
||||
|
||||
nbConfig := &proto.NetbirdConfig{
|
||||
Stuns: stuns,
|
||||
Turns: turns,
|
||||
Signal: signalCfg,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if networkMap.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
|
||||
for _, rule := range networkMap.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
return proto.HostConfig_UDP
|
||||
case nbconfig.DTLS:
|
||||
return proto.HostConfig_DTLS
|
||||
case nbconfig.HTTP:
|
||||
return proto.HostConfig_HTTP
|
||||
case nbconfig.HTTPS:
|
||||
return proto.HostConfig_HTTPS
|
||||
case nbconfig.TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result[i] = fwRule
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
150
management/internals/shared/grpc/conversion_test.go
Normal file
150
management/internals/shared/grpc/conversion_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache cache.DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -20,7 +22,7 @@ import (
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
|
||||
@@ -44,15 +46,18 @@ import (
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
type GRPCServer struct {
|
||||
// Server an instance of a Management gRPC API server
|
||||
type Server struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
@@ -63,21 +68,28 @@ type GRPCServer struct {
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
loginFilter *loginFilter
|
||||
|
||||
networkMapController network_map.Controller
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(
|
||||
ctx context.Context,
|
||||
config *nbconfig.Config,
|
||||
accountManager account.Manager,
|
||||
settingsManager settings.Manager,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
peersUpdateManager network_map.PeersUpdateManager,
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager ephemeral.Manager,
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
) (*GRPCServer, error) {
|
||||
networkMapController network_map.Controller,
|
||||
) (*Server, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -86,7 +98,7 @@ func NewServer(
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(len(peersUpdateManager.peerChannels))
|
||||
return int64(peersUpdateManager.CountStreams())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,7 +108,18 @@ func NewServer(
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
return &GRPCServer{
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
//nolint:gosec
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
@@ -110,10 +133,15 @@ func NewServer(
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
networkMapController: networkMapController,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
ip := ""
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if ok {
|
||||
@@ -150,7 +178,12 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
@@ -158,13 +191,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
}
|
||||
@@ -172,6 +206,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
@@ -183,42 +218,54 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
start := time.Now()
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -235,13 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
for {
|
||||
select {
|
||||
@@ -275,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
@@ -293,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
@@ -308,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.authManager == nil {
|
||||
return "", status.Errorf(codes.Internal, "missing auth manager")
|
||||
}
|
||||
@@ -342,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
return userAuth.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
@@ -450,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
@@ -469,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
|
||||
// In case it is, the login is successful
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
reqStart := time.Now()
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
@@ -483,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
}
|
||||
@@ -509,10 +556,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
took := time.Since(reqStart)
|
||||
if took > 7*time.Second {
|
||||
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
@@ -546,9 +599,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
@@ -557,6 +613,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
@@ -569,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
@@ -588,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
@@ -600,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
|
||||
//
|
||||
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
|
||||
// registered or if it uses a setup key to register.
|
||||
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
userID := ""
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
@@ -620,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
return proto.HostConfig_UDP
|
||||
case nbconfig.DTLS:
|
||||
return proto.HostConfig_DTLS
|
||||
case nbconfig.HTTP:
|
||||
return proto.HostConfig_HTTP
|
||||
case nbconfig.HTTPS:
|
||||
return proto.HostConfig_HTTPS
|
||||
case nbconfig.TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
for _, stun := range config.Stuns {
|
||||
stuns = append(stuns, &proto.HostConfig{
|
||||
Uri: stun.URI,
|
||||
Protocol: ToResponseProto(stun.Proto),
|
||||
})
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
if config.TURNConfig != nil {
|
||||
for _, turn := range config.TURNConfig.Turns {
|
||||
var username string
|
||||
var password string
|
||||
if turnCredentials != nil {
|
||||
username = turnCredentials.Payload
|
||||
password = turnCredentials.Signature
|
||||
} else {
|
||||
username = turn.Username
|
||||
password = turn.Password
|
||||
}
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: turn.URI,
|
||||
Protocol: ToResponseProto(turn.Proto),
|
||||
},
|
||||
User: username,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var relayCfg *proto.RelayConfig
|
||||
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
|
||||
relayCfg = &proto.RelayConfig{
|
||||
Urls: config.Relay.Addresses,
|
||||
}
|
||||
|
||||
if relayToken != nil {
|
||||
relayCfg.TokenPayload = relayToken.Payload
|
||||
relayCfg.TokenSignature = relayToken.Signature
|
||||
}
|
||||
}
|
||||
|
||||
var signalCfg *proto.HostConfig
|
||||
if config.Signal != nil {
|
||||
signalCfg = &proto.HostConfig{
|
||||
Uri: config.Signal.URI,
|
||||
Protocol: ToResponseProto(config.Signal.Proto),
|
||||
}
|
||||
}
|
||||
|
||||
nbConfig := &proto.NetbirdConfig{
|
||||
Stuns: stuns,
|
||||
Turns: turns,
|
||||
Signal: signalCfg,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if networkMap.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
|
||||
for _, rule := range networkMap.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// IsHealthy indicates whether the service is healthy
|
||||
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
|
||||
func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
|
||||
var err error
|
||||
|
||||
var turnToken *Token
|
||||
@@ -803,29 +708,24 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
|
||||
peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
// Get all peers in the account for forwarder port computation
|
||||
allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account peers: %w", err)
|
||||
}
|
||||
dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion)
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
sendStart := time.Now()
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||
@@ -838,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information
|
||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -896,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
|
||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -951,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
|
||||
|
||||
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
|
||||
// peer's under the same account of any updates.
|
||||
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
@@ -976,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||
start := time.Now()
|
||||
|
||||
106
management/internals/shared/grpc/server_test.go
Normal file
106
management/internals/shared/grpc/server_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *config.DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
||||
expectedComparisonMSG string
|
||||
}{
|
||||
{
|
||||
name: "Testing No Device Flow Config",
|
||||
inputFlow: nil,
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.NoError,
|
||||
expectedErrMSG: "should not return error",
|
||||
expectedComparisonFunc: require.Equal,
|
||||
expectedComparisonMSG: "should match",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &Server{
|
||||
wgKey: testingServerKey,
|
||||
config: &config.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
context.TODO(),
|
||||
&mgmtProto.EncryptedMessage{
|
||||
WgPubKey: testingClientKey.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
},
|
||||
)
|
||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct {
|
||||
relayCfg *nbconfig.Relay
|
||||
turnHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager *PeersUpdateManager
|
||||
updateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
groupsManager groups.Manager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
@@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct {
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
mgr := &TimeBasedAuthSecretsManager{
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
@@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||
@@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{
|
||||
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
@@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: 2 * time.Second}
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
updateChannel := peersManager.CreateChannel(context.Background(), peer)
|
||||
|
||||
@@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
|
||||
}
|
||||
|
||||
var updates []*UpdateMessage
|
||||
var updates []*network_map.UpdateMessage
|
||||
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
@@ -185,7 +187,7 @@ loop:
|
||||
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
|
||||
rc := &config.Relay{
|
||||
@@ -1,11 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
"log"
|
||||
"net/http"
|
||||
// nolint:gosec
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -11,10 +11,8 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cacheStore "github.com/eko/gocache/lib/v4/store"
|
||||
@@ -26,6 +24,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -68,7 +67,7 @@ type DefaultAccountManager struct {
|
||||
cacheMux sync.Mutex
|
||||
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
|
||||
cacheLoading map[string]chan struct{}
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
networkMapController network_map.Controller
|
||||
idpManager idp.Manager
|
||||
cacheManager *nbcache.AccountUserDataCache
|
||||
externalCacheManager nbcache.UserDataCache
|
||||
@@ -88,8 +87,7 @@ type DefaultAccountManager struct {
|
||||
singleAccountMode bool
|
||||
// singleAccountModeDomain is a domain to use in singleAccountMode setup
|
||||
singleAccountModeDomain string
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
|
||||
peerLoginExpiry Scheduler
|
||||
|
||||
peerInactivityExpiry Scheduler
|
||||
@@ -103,14 +101,11 @@ type DefaultAccountManager struct {
|
||||
|
||||
permissionsManager permissions.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
loginFilter *loginFilter
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||
@@ -177,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
|
||||
func BuildManager(
|
||||
ctx context.Context,
|
||||
store store.Store,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
networkMapController network_map.Controller,
|
||||
idpManager idp.Manager,
|
||||
singleAccountModeDomain string,
|
||||
dnsDomain string,
|
||||
eventStore activity.Store,
|
||||
geo geolocation.Geolocation,
|
||||
userDeleteFromIDPEnabled bool,
|
||||
@@ -199,12 +193,11 @@ func BuildManager(
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
geo: geo,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
networkMapController: networkMapController,
|
||||
idpManager: idpManager,
|
||||
ctx: context.Background(),
|
||||
cacheMux: sync.Mutex{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
peerInactivityExpiry: NewDefaultScheduler(),
|
||||
@@ -215,11 +208,10 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
am.networkMapController.StartWarmup(ctx)
|
||||
|
||||
accountsCounter, err := store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
@@ -267,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
|
||||
am.ephemeralManager = em
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
|
||||
var initialInterval int64
|
||||
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
|
||||
interval, err := strconv.Atoi(intervalStr)
|
||||
if err != nil {
|
||||
initialInterval = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
|
||||
} else {
|
||||
initialInterval = int64(interval) * 10
|
||||
go func() {
|
||||
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
|
||||
startupPeriod, err := strconv.Atoi(startupPeriodStr)
|
||||
if err != nil {
|
||||
startupPeriod = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
|
||||
}
|
||||
time.Sleep(time.Duration(startupPeriod) * time.Second)
|
||||
am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
|
||||
}()
|
||||
}
|
||||
am.updateAccountPeersBufferInterval.Store(initialInterval)
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
|
||||
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
|
||||
return am.externalCacheManager
|
||||
}
|
||||
@@ -1636,19 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
|
||||
return am.loginFilter.allowLogin(wgPubKey, metahash)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
|
||||
@@ -1656,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
metahash := metaHash(meta, realIP.String())
|
||||
am.loginFilter.addLogin(peerPubKey, metahash)
|
||||
|
||||
return peer, netMap, postureChecks, nil
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
@@ -1676,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return mapError(ctx, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||
}
|
||||
|
||||
// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
|
||||
func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
|
||||
return am.peersUpdateManager.HasChannel(peerID)
|
||||
}
|
||||
|
||||
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
||||
|
||||
func isDomainValid(domain string) bool {
|
||||
return invalidDomainRegexp.MatchString(domain)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
return am.dnsDomain
|
||||
}
|
||||
if settings.DNSDomain == "" {
|
||||
return am.dnsDomain
|
||||
}
|
||||
|
||||
return settings.DNSDomain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
|
||||
peers := []*nbpeer.Peer{}
|
||||
log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID)
|
||||
@@ -2129,7 +2061,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2177,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account settings: %w", err)
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
dnsDomain := am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
eventMeta := peer.EventMeta(dnsDomain)
|
||||
oldIP := peer.IP.String()
|
||||
|
||||
@@ -89,7 +89,6 @@ type Manager interface {
|
||||
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
|
||||
@@ -97,10 +96,8 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
HasConnectedChannel(peerID string) bool
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
@@ -110,7 +107,7 @@ type Manager interface {
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
@@ -127,5 +124,4 @@ type Manager interface {
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
}
|
||||
|
||||
11
management/server/account/request_buffer.go
Normal file
11
management/server/account/request_buffer.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type RequestBuffer interface {
|
||||
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
|
||||
}
|
||||
@@ -22,6 +22,9 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
|
||||
@@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
@@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
DomainCategory: types.PublicCategory,
|
||||
}
|
||||
|
||||
am, err := createManager(b)
|
||||
am, _, err := createManager(b)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
return
|
||||
@@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User {
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1154,8 +1157,17 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1181,8 +1193,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -1205,11 +1217,20 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
// Ensure that we do not receive an update message before the policy is deleted
|
||||
time.Sleep(time.Second)
|
||||
@@ -1239,8 +1260,17 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
AccountID: account.Id,
|
||||
@@ -1253,8 +1283,8 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -1288,8 +1318,17 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1318,8 +1357,11 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
// We need to sleep to wait for the buffer peer update
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -1341,11 +1383,20 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1377,6 +1428,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
for drained := false; !drained; {
|
||||
select {
|
||||
case <-updMsg:
|
||||
default:
|
||||
drained = true
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1404,7 +1463,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1485,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy
|
||||
}
|
||||
|
||||
func TestGetUsersFromAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1736,7 +1795,9 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1782,7 +1843,7 @@ func hasNilField(x interface{}) error {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1797,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1853,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1896,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1958,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -2622,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
// create a new account
|
||||
@@ -2864,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
// Fatalf(format string, args ...interface{})
|
||||
// }
|
||||
|
||||
func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
@@ -2893,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
func createStore(t testing.TB) (store.Store, error) {
|
||||
@@ -2927,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
t.Helper()
|
||||
|
||||
manager, err := createManager(t)
|
||||
manager, updateManager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -2971,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
|
||||
peer2 := getPeer(manager, setupKey)
|
||||
peer3 := getPeer(manager, setupKey)
|
||||
|
||||
return manager, account, peer1, peer2, peer3
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
@@ -2984,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
@@ -3022,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -3031,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -3085,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -3094,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
@@ -3155,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -3164,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
@@ -3227,7 +3289,7 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -3273,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -3303,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("memory cache", func(t *testing.T) {
|
||||
@@ -3353,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -3470,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
@@ -3502,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
@@ -3541,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -3608,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -3654,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
} else {
|
||||
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
|
||||
if err != nil {
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(conns)
|
||||
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
|
||||
if err != nil {
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
conns = 1
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(conns)
|
||||
sqlDB.SetMaxIdleConns(conns)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
sqlDB.SetConnMaxIdleTime(3 * time.Minute)
|
||||
|
||||
log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
|
||||
conns, conns, time.Hour, 3*time.Minute)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -3,54 +3,23 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsForwarderPort = nbdns.ForwarderServerPort
|
||||
oldForwarderPort = nbdns.ForwarderClientPort
|
||||
)
|
||||
|
||||
const dnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
@@ -191,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
if len(peers) == 0 {
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
|
||||
reqVer := semver.Canonical(requiredVersion)
|
||||
|
||||
// Check if all peers have the required version or newer
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
// If any peer doesn't have version info, return 0
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
}
|
||||
|
||||
// All peers have the required version or newer
|
||||
return int64(dnsForwarderPort)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock())
|
||||
|
||||
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeForwarderPort(t *testing.T) {
|
||||
// Test with empty peers list
|
||||
peers := []*nbpeer.Peer{}
|
||||
result := computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have old versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.26.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have new versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(dnsForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have mixed versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have empty version
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "development",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result == int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have unknown version string
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "unknown",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
@@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||
|
||||
@@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||
@@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
if oldGroup.Name != newGroup.Name {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"old_name": oldGroup.Name,
|
||||
"new_name": newGroup.Name,
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
@@ -354,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
dnsDomain := am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
|
||||
@@ -37,7 +37,7 @@ const (
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account manager: %s", err)
|
||||
}
|
||||
@@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
assert.NoError(t, err, "Failed to create account manager")
|
||||
|
||||
manager, account, err := initTestGroupAccount(am)
|
||||
@@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
}
|
||||
|
||||
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
g := []*types.Group{
|
||||
{
|
||||
@@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving a group that is not linked to any resource should not update account peers
|
||||
@@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeerToGroup(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeerToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP {
|
||||
}
|
||||
|
||||
func Test_IncrementNetworkSerial(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user