Compare commits

...

28 Commits

Author SHA1 Message Date
crn4
f52fda9b3c save nmaps for all 2025-11-17 15:15:16 +01:00
crn4
56bdc6eab3 Merge branch 'vk/compare-nmaps' into dbg/bothmaps 2025-11-17 15:14:23 +01:00
crn4
5603d36165 added exception on not appending route firewall rules if we have all wildcard 2025-11-17 14:25:04 +01:00
crn4
acc23f469e Merge branch 'main' into vk/compare-nmaps 2025-11-14 15:05:34 +01:00
crn4
9405c014c3 more ids 2025-11-14 14:58:40 +01:00
crn4
4536dcb2b2 exception 2025-11-14 14:37:46 +01:00
crn4
590b414ab7 save all maps 2025-11-14 13:52:53 +01:00
Viktor Liu
e4b41d0ad7 [client] Replace ipset lib (#4777)
* Replace ipset lib

* Update .github/workflows/check-license-dependencies.yml

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Ignore internal licenses

* Ignore dependencies from AGPL code

* Use exported errors

* Use fixed version

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-11-14 00:25:00 +01:00
Viktor Liu
9cc9462cd5 [client] Use stdnet with a context to avoid DNS deadlocks (#4781) 2025-11-13 20:16:45 +01:00
crn4
a45ab85178 limit differences to 5 bytes - fixed 2025-11-13 17:48:40 +01:00
crn4
c0698c8153 limit differences to 5 bytes 2025-11-13 17:19:24 +01:00
crn4
b0042c5cd0 added both maps generation and save to file 2025-11-13 16:59:45 +01:00
Diego Romar
3176b53968 [client] Add quick actions window (#4717)
* Open quick settings window if netbird-ui is already running

* [client-ui] fix connection status comparison

* [client-ui] modularize quick actions code

* [client-ui] add netbird-disconnected logo

* [client-ui] change quickactions UI

It now displays the NetBird logo and a single button
with a round icon

* [client-ui] add hint message to quick actions screen

This also updates fyne to v2.7.0

* [client-ui] remove unnecessary default clause

* [client-ui] remove commented code

* [client-ui] remove unused dependency

* [client-ui] close quick actions on connection change

* [client-ui] add function to get image from embed resources

* [client] Return error when calling sendShowWindowSignal from Windows

* [client-ui] Add commentary on empty OnTapped function for toggleConnectionButton

* [client-ui] Fix tests

* [client-ui] Add context to menuUpClick call

* [client-ui] Pass serviceClient app as parameter

To use its clipboard rather than the window's when showing
the upload success dialog

* [client-ui] Replace for select with for range chan

* [client-ui] Replace settings change listener channel

Settings now accept a function callback

* [client-ui] Add missing iconAboutDisconnected to icons_windows.go

* [client] Add quick actions signal handler for Windows with named events

* [client] Run go mod tidy

* [client] Remove line break

* [client] Log unexpected status in separate function

* [client-ui] Refactor quick actions window

To address racing conditions, it also replaces
usage of pause and resume channels with an
atomic bool.

* [client-ui] use derived context from ServiceClient

* [client] Update signal_windows log message

Also, format error when trying to set event on
sendShowWindowSignal

* go mod tidy

* [client-ui] Add struct to pass fewer parameters

to applyQuickActionsUiState function

* [client] Add missing import

---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
2025-11-13 10:25:19 -03:00
Viktor Liu
27957036c9 [client] Fix shutdown blocking on stuck ICE agent close (#4780) 2025-11-13 13:24:51 +01:00
Pascal Fischer
6fb568728f [management] Removed policy posture checks on original peer (#4779)
Co-authored-by: crn4 <vladimir@netbird.io>
2025-11-13 12:51:03 +01:00
Pascal Fischer
cc97cffff1 [management] move network map logic into new design (#4774) 2025-11-13 12:09:46 +01:00
Zoltan Papp
c28275611b Fix agent reference (#4776) 2025-11-11 13:59:32 +01:00
Vlad
56f169eede [management] fix pg db deadlock after app panic (#4772) 2025-11-10 23:43:08 +01:00
Viktor Liu
07cf9d5895 [client] Create networkd.conf.d if it doesn't exist (#4764) 2025-11-08 10:54:37 +01:00
Pascal Fischer
7df49e249d [management ] remove timing logs (#4761) 2025-11-07 20:14:52 +01:00
Pascal Fischer
dbfc8a52c9 [management] remove GLOBAL when disabling foreign keys on mysql (#4615) 2025-11-07 16:03:14 +01:00
Vlad
98ddac07bf [management] remove toAll firewall rule (#4725) 2025-11-07 15:50:58 +01:00
Pascal Fischer
48475ddc05 [management] add pat rate limiting (#4741) 2025-11-07 15:50:18 +01:00
Vlad
6aa4ba7af4 [management] incremental network map builder (#4753) 2025-11-07 10:44:46 +01:00
dependabot[bot]
2e16c9914a [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29.
- [Release notes](https://github.com/containerd/containerd/releases)
- [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md)
- [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29)

---
updated-dependencies:
- dependency-name: github.com/containerd/containerd
  dependency-version: 1.7.29
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-06 19:01:44 +03:00
Pascal Fischer
5c29d395b2 [management] activity events on group updates (#4750) 2025-11-06 12:51:14 +01:00
Viktor Liu
229e0038ee [client] Add dns config to debug bundle (#4704) 2025-11-05 17:30:17 +01:00
Viktor Liu
75327d9519 [client] Add login_hint to oidc flows (#4724) 2025-11-05 17:00:20 +01:00
120 changed files with 11400 additions and 2442 deletions

View File

@@ -3,10 +3,19 @@ name: Check License Dependencies
on:
push:
branches: [ main ]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
jobs:
check-dependencies:
check-internal-dependencies:
name: Check Internal AGPL Dependencies
runs-on: ubuntu-latest
steps:
@@ -33,9 +42,67 @@ jobs:
if [ $FOUND_ISSUES -eq 1 ]; then
echo ""
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages will change license and should not be imported by client or shared code"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All license dependencies are clean"
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"

View File

@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
if err != nil {
return nil, err
}

View File

@@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
@@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
needsLogin := false
err := WithBackOff(func() error {
@@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
if err != nil {
return nil, err
}

View File

@@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) {
}
const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
@@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no
// configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error {
parentDir := filepath.Dir(networkdConfDir)
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
log.Debug("systemd-networkd not in use, skipping configuration")
return nil
}
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err)

View File

@@ -12,6 +12,9 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
@@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
@@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error
var loginResp *proto.LoginResponse

View File

@@ -1,13 +1,14 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -40,19 +41,13 @@ type aclManager struct {
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
m := &aclManager{
return &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return m, nil
}, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
}}, nil
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
shouldDestroyIpset = true
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
}
m.ipsetStore.deleteIpset(ipsetName)
}
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
// don't use IP matching if IP is 0.0.0.0
if ip.IsUnspecified() {
matchByIP = false
}
@@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
return r, nil
}
@@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
if err := r.createIPSet(setName); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
@@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
if err := r.destroyIPSet(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
@@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
@@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}
func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -1,6 +1,7 @@
package iface
import (
"context"
"fmt"
"net"
"net/netip"
@@ -9,13 +10,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// keep darwin compatibility
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
peer2wgPort := 33200
keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet()
newNet, err = stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,6 +1,7 @@
package udpmux
import (
"context"
"fmt"
"io"
"net"
@@ -12,8 +13,9 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
/*
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}

View File

@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
if d.providerConfig.LoginHint != "" {
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
if deviceCode.VerificationURI != "" {
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
}
}
return deviceCode, err
}
func appendLoginHint(uri, loginHint string) string {
if uri == "" || loginHint == "" {
return uri
}
parsedURL, err := url.Parse(uri)
if err != nil {
log.Debugf("failed to parse verification URI for login_hint: %v", err)
return uri
}
query := parsedURL.Query()
query.Set("login_hint", loginHint)
parsedURL.RawQuery = query.Encode()
return parsedURL.String()
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)

View File

@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
return authenticateWithDeviceCodeFlow(ctx, config, hint)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config)
return authenticateWithDeviceCodeFlow(ctx, config, hint)
}
return pkceFlow, nil
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
if p.providerConfig.LoginHint != "" {
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
@@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
DNS Configuration
The debug bundle includes platform-specific DNS configuration files:
resolv.conf (Unix systems):
- Contains DNS resolver configuration from /etc/resolv.conf
- Includes nameserver entries, search domains, and resolver options
- All IP addresses and domain names are anonymized following the same rules as other files
scutil_dns.txt (macOS only):
- Contains detailed DNS configuration from scutil --dns
- Shows DNS configuration for all network interfaces
- Includes search domains, nameservers, and DNS resolver settings
- All IP addresses and domain names are anonymized
`
const (
@@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() {
if err := g.addFirewallRules(); err != nil {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}
}
func (g *BundleGenerator) addReadme() error {

View File

@@ -0,0 +1,53 @@
//go:build darwin && !ios
package debug
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
if err := g.addScutilDNS(); err != nil {
log.Errorf("failed to add scutil DNS output: %v", err)
}
return nil
}
func (g *BundleGenerator) addScutilDNS() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "scutil", "--dns")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("execute scutil --dns: %w", err)
}
if len(bytes.TrimSpace(output)) == 0 {
return fmt.Errorf("no scutil DNS output")
}
content := string(output)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
return fmt.Errorf("add scutil DNS output to zip: %w", err)
}
return nil
}

View File

@@ -5,3 +5,7 @@ package debug
func (g *BundleGenerator) addRoutes() error {
return nil
}
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -0,0 +1,16 @@
//go:build unix && !darwin && !android
package debug
import (
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
return nil
}

View File

@@ -0,0 +1,7 @@
//go:build !unix
package debug
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -0,0 +1,29 @@
//go:build unix && !android
package debug
import (
"fmt"
"os"
"strings"
)
const resolvConfPath = "/etc/resolv.conf"
func (g *BundleGenerator) addResolvConf() error {
data, err := os.ReadFile(resolvConfPath)
if err != nil {
return fmt.Errorf("read %s: %w", resolvConfPath, err)
}
content := string(data)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
return fmt.Errorf("add resolv.conf to zip: %w", err)
}
return nil
}

View File

@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it

View File

@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(nil)
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet([]string{"utun2301"})
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet([]string{"utun2301"})
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err

View File

@@ -7,5 +7,5 @@ import (
)
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceBlackList)
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
}

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -26,6 +26,9 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -771,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -974,7 +977,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -1556,7 +1559,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -1584,13 +1586,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err)
}

View File

@@ -1,6 +1,7 @@
package ice
import (
"context"
"sync"
"time"
@@ -22,6 +23,8 @@ const (
iceFailedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
iceAgentCloseTimeout = 3 * time.Second
)
type ThreadSafeAgent struct {
@@ -32,18 +35,28 @@ type ThreadSafeAgent struct {
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
err = a.Agent.Close()
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceFailedTimeout := iceFailedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}

View File

@@ -3,9 +3,11 @@
package ice
import (
"context"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ifaceBlacklist)
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ctx, ifaceBlacklist)
}

View File

@@ -1,7 +1,11 @@
package ice
import "github.com/netbirdio/netbird/client/internal/stdnet"
import (
"context"
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
}

View File

@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
if isController(w.config) {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}

View File

@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it

View File

@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
net, err := stdnet.NewNet(nil)
net, err := stdnet.NewNet(ctx, nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
net, err := stdnet.NewNet(nil)
net, err := stdnet.NewNet(ctx, nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return

View File

@@ -6,7 +6,7 @@ import (
"net/netip"
"testing"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/stretchr/testify/require"
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -15,7 +15,7 @@ import (
"syscall"
"testing"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
require.NoError(t, err)
opts := iface.WGIFaceOpts{

View File

@@ -4,17 +4,28 @@
package stdnet
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strconv"
"sync"
"time"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const updateInterval = 30 * time.Second
const (
updateInterval = 30 * time.Second
dnsResolveTimeout = 30 * time.Second
)
var errNoSuitableAddress = errors.New("no suitable address found")
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
@@ -28,12 +39,19 @@ type Net struct {
// mu is shared between interfaces and lastUpdate
mu sync.Mutex
// ctx is the context for network operations that supports cancellation
ctx context.Context
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
// so in android cli use pionDiscover
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
}
// NewNet creates a new StdNet instance.
func NewNet(disallowList []string) (*Net, error) {
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
return n, n.UpdateInterfaces()
}
// resolveAddr performs DNS resolution with context support and timeout.
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return netip.AddrPort{}, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
}
if port < 0 || port > 65535 {
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
}
ipNet := "ip"
switch network {
case "tcp4", "udp4":
ipNet = "ip4"
case "tcp6", "udp6":
ipNet = "ip6"
}
if host == "" {
addr := netip.IPv4Unspecified()
if ipNet == "ip6" {
addr = netip.IPv6Unspecified()
}
return netip.AddrPortFrom(addr, uint16(port)), nil
}
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
defer cancel()
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
if err != nil {
return netip.AddrPort{}, err
}
if len(addrs) == 0 {
return netip.AddrPort{}, errNoSuitableAddress
}
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
}
// UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
}
return result
}
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
switch network {
case "udp", "udp4", "udp6":
case "":
network = "udp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
}
return net.UDPAddrFromAddrPort(addrPort), nil
}
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
switch network {
case "tcp", "tcp4", "tcp6":
case "":
network = "tcp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
}
return net.TCPAddrFromAddrPort(addrPort), nil
}

View File

@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile,
})
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
if err != nil {
return err.Error()
}

View File

@@ -279,8 +279,10 @@ type LoginRequest struct {
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
// hint is used to pre-fill the email/username field during SSO authentication
Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LoginRequest) Reset() {
@@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 {
return 0
}
func (x *LoginRequest) GetHint() string {
if x != nil && x.Hint != nil {
return *x.Hint
}
return ""
}
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" +
"\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\xc3\x0e\n" +
"\fEmptyRequest\"\xe5\x0e\n" +
"\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" +
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" +
@@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_block_inboundB\x0e\n" +
"\f_profileNameB\v\n" +
"\t_usernameB\x06\n" +
"\x04_mtu\"\xb5\x01\n" +
"\x04_mtuB\a\n" +
"\x05_hint\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +

View File

@@ -158,6 +158,9 @@ message LoginRequest {
optional string username = 31;
optional int64 mtu = 32;
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
}
message LoginResponse {

View File

@@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting)
if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err

View File

@@ -14,6 +14,9 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.4 KiB

View File

@@ -85,21 +85,22 @@ func main() {
// Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(&newServiceClientArgs{
addr: flags.daemonAddr,
logFile: logFile,
app: a,
showSettings: flags.showSettings,
showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
addr: flags.daemonAddr,
logFile: logFile,
app: a,
showSettings: flags.showSettings,
showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
showQuickActions: flags.showQuickActions,
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
a.Run()
return
}
@@ -111,23 +112,29 @@ func main() {
return
}
if running {
log.Warnf("another process is running with pid %d, exiting", pid)
log.Infof("another process is running with pid %d, sending signal to show window", pid)
if err := sendShowWindowSignal(pid); err != nil {
log.Errorf("send signal to running instance: %v", err)
}
return
}
client.setupSignalHandler(client.ctx)
client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit)
}
type cliFlags struct {
daemonAddr string
showSettings bool
showNetworks bool
showProfiles bool
showDebug bool
showLoginURL bool
errorMsg string
saveLogsInFile bool
daemonAddr string
showSettings bool
showNetworks bool
showProfiles bool
showDebug bool
showLoginURL bool
showQuickActions bool
errorMsg string
saveLogsInFile bool
}
// parseFlags reads and returns all needed command-line flags.
@@ -143,6 +150,7 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window")
flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window")
flag.BoolVar(&flags.showDebug, "debug", false, "run debug window")
flag.BoolVar(&flags.showQuickActions, "quick-actions", false, "run quick actions window")
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
@@ -158,11 +166,9 @@ func initLogFile() (string, error) {
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
func watchSettingsChanges(a fyne.App, client *serviceClient) {
settingsChangeChan := make(chan fyne.Settings)
a.Settings().AddChangeListener(settingsChangeChan)
for range settingsChangeChan {
a.Settings().AddListener(func(settings fyne.Settings) {
client.updateIcon()
}
})
}
// showErrorMessage displays an error message in a simple window.
@@ -287,6 +293,7 @@ type serviceClient struct {
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
eventManager *event.Manager
@@ -306,14 +313,15 @@ type menuHandler struct {
}
type newServiceClientArgs struct {
addr string
logFile string
app fyne.App
showSettings bool
showNetworks bool
showDebug bool
showLoginURL bool
showProfiles bool
addr string
logFile string
app fyne.App
showSettings bool
showNetworks bool
showDebug bool
showLoginURL bool
showProfiles bool
showQuickActions bool
}
// newServiceClient instance constructor
@@ -349,6 +357,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showDebugUI()
case args.showProfiles:
s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
}
return s
@@ -610,11 +620,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err)
}
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
Username: &currUser.Username,
})
}
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginReq.Hint = &profileState.Email
}
loginResp, err := conn.Login(ctx, loginReq)
if err != nil {
return nil, fmt.Errorf("login to management: %w", err)
}

View File

@@ -500,7 +500,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
if uploadFailureReason != "" {
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
} else {
showUploadSuccessDialog(progress.window, localPath, uploadedKey)
showUploadSuccessDialog(s.app, progress.window, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(progress.window, localPath)
@@ -565,7 +565,7 @@ func (s *serviceClient) handleDebugCreation(
if uploadFailureReason != "" {
showUploadFailedDialog(w, localPath, uploadFailureReason)
} else {
showUploadSuccessDialog(w, localPath, uploadedKey)
showUploadSuccessDialog(s.app, w, localPath, uploadedKey)
}
} else {
showBundleCreatedDialog(w, localPath)
@@ -665,7 +665,7 @@ func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
}
// showUploadSuccessDialog displays a dialog when upload succeeds
func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey string) {
log.Infof("Upload key: %s", uploadedKey)
keyEntry := widget.NewEntry()
keyEntry.SetText(uploadedKey)
@@ -683,7 +683,7 @@ func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
copyBtn := createButtonWithAction("Copy key", func() {
w.Clipboard().SetContent(uploadedKey)
a.Clipboard().SetContent(uploadedKey)
log.Info("Upload key copied to clipboard")
})

View File

@@ -9,6 +9,9 @@ import (
//go:embed assets/netbird.png
var iconAbout []byte
//go:embed assets/netbird-disconnected.png
var iconAboutDisconnected []byte
//go:embed assets/netbird-systemtray-connected.png
var iconConnected []byte

View File

@@ -7,6 +7,9 @@ import (
//go:embed assets/netbird.ico
var iconAbout []byte
//go:embed assets/netbird-disconnected.ico
var iconAboutDisconnected []byte
//go:embed assets/netbird-systemtray-connected.ico
var iconConnected []byte

349
client/ui/quickactions.go Normal file
View File

@@ -0,0 +1,349 @@
//go:build !(linux && 386)
//go:generate fyne bundle -o quickactions_assets.go assets/connected.png
//go:generate fyne bundle -o quickactions_assets.go -append assets/disconnected.png
package main
import (
"context"
_ "embed"
"fmt"
"runtime"
"sync/atomic"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/canvas"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
type quickActionsUiState struct {
connectionStatus string
isToggleButtonEnabled bool
isConnectionChanged bool
toggleAction func()
}
func newQuickActionsUiState() quickActionsUiState {
return quickActionsUiState{
connectionStatus: string(internal.StatusIdle),
isToggleButtonEnabled: false,
isConnectionChanged: false,
}
}
type clientConnectionStatusProvider interface {
connectionStatus(ctx context.Context) (string, error)
}
type daemonClientConnectionStatusProvider struct {
client proto.DaemonServiceClient
}
func (d daemonClientConnectionStatusProvider) connectionStatus(ctx context.Context) (string, error) {
childCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond)
defer cancel()
status, err := d.client.Status(childCtx, &proto.StatusRequest{})
if err != nil {
return "", err
}
return status.Status, nil
}
type clientCommand interface {
execute() error
}
type connectCommand struct {
connectClient func() error
}
func (c connectCommand) execute() error {
return c.connectClient()
}
type disconnectCommand struct {
disconnectClient func() error
}
func (c disconnectCommand) execute() error {
return c.disconnectClient()
}
type quickActionsViewModel struct {
provider clientConnectionStatusProvider
connect clientCommand
disconnect clientCommand
uiChan chan quickActionsUiState
isWatchingConnectionStatus atomic.Bool
}
func newQuickActionsViewModel(ctx context.Context, provider clientConnectionStatusProvider, connect, disconnect clientCommand, uiChan chan quickActionsUiState) {
viewModel := quickActionsViewModel{
provider: provider,
connect: connect,
disconnect: disconnect,
uiChan: uiChan,
}
viewModel.isWatchingConnectionStatus.Store(true)
// base UI status
uiChan <- newQuickActionsUiState()
// this retrieves the current connection status
// and pushes the UI state that reflects it via uiChan
go viewModel.watchConnectionStatus(ctx)
}
func (q *quickActionsViewModel) updateUiState(ctx context.Context) {
uiState := newQuickActionsUiState()
connectionStatus, err := q.provider.connectionStatus(ctx)
if err != nil {
log.Errorf("Status: Error - %v", err)
q.uiChan <- uiState
return
}
if connectionStatus == string(internal.StatusConnected) {
uiState.toggleAction = func() {
q.executeCommand(q.disconnect)
}
} else {
uiState.toggleAction = func() {
q.executeCommand(q.connect)
}
}
uiState.isToggleButtonEnabled = true
uiState.connectionStatus = connectionStatus
q.uiChan <- uiState
}
func (q *quickActionsViewModel) watchConnectionStatus(ctx context.Context) {
ticker := time.NewTicker(1000 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if q.isWatchingConnectionStatus.Load() {
q.updateUiState(ctx)
}
}
}
}
func (q *quickActionsViewModel) executeCommand(command clientCommand) {
uiState := newQuickActionsUiState()
// newQuickActionsUiState starts with Idle connection status,
// and all that's necessary here is to just disable the toggle button.
uiState.connectionStatus = ""
q.uiChan <- uiState
q.isWatchingConnectionStatus.Store(false)
err := command.execute()
if err != nil {
log.Errorf("Status: Error - %v", err)
q.isWatchingConnectionStatus.Store(true)
} else {
uiState = newQuickActionsUiState()
uiState.isConnectionChanged = true
q.uiChan <- uiState
}
}
func getSystemTrayName() string {
os := runtime.GOOS
switch os {
case "darwin":
return "menu bar"
default:
return "system tray"
}
}
func (s *serviceClient) getNetBirdImage(name string, content []byte) *canvas.Image {
imageSize := fyne.NewSize(64, 64)
resource := fyne.NewStaticResource(name, content)
image := canvas.NewImageFromResource(resource)
image.FillMode = canvas.ImageFillContain
image.SetMinSize(imageSize)
image.Resize(imageSize)
return image
}
type quickActionsUiComponents struct {
content *fyne.Container
toggleConnectionButton *widget.Button
connectedLabelText, disconnectedLabelText string
connectedImage, disconnectedImage *canvas.Image
connectedCircleRes, disconnectedCircleRes fyne.Resource
}
// applyQuickActionsUiState applies a single UI state to the quick actions window.
// It closes the window and returns true if the connection status has changed,
// in which case the caller should stop processing further states.
func (s *serviceClient) applyQuickActionsUiState(
uiState quickActionsUiState,
components quickActionsUiComponents,
) bool {
if uiState.isConnectionChanged {
fyne.DoAndWait(func() {
s.wQuickActions.Close()
})
return true
}
var logo *canvas.Image
var buttonText string
var buttonIcon fyne.Resource
if uiState.connectionStatus == string(internal.StatusConnected) {
buttonText = components.connectedLabelText
buttonIcon = components.connectedCircleRes
logo = components.connectedImage
} else if uiState.connectionStatus == string(internal.StatusIdle) {
buttonText = components.disconnectedLabelText
buttonIcon = components.disconnectedCircleRes
logo = components.disconnectedImage
}
fyne.DoAndWait(func() {
if buttonText != "" {
components.toggleConnectionButton.SetText(buttonText)
}
if buttonIcon != nil {
components.toggleConnectionButton.SetIcon(buttonIcon)
}
if uiState.isToggleButtonEnabled {
components.toggleConnectionButton.Enable()
} else {
components.toggleConnectionButton.Disable()
}
components.toggleConnectionButton.OnTapped = func() {
if uiState.toggleAction != nil {
go uiState.toggleAction()
}
}
components.toggleConnectionButton.Refresh()
// the second position in the content's object array is the NetBird logo.
if logo != nil {
components.content.Objects[1] = logo
components.content.Refresh()
}
})
return false
}
// showQuickActionsUI displays a simple window with the NetBird logo and a connection toggle button.
func (s *serviceClient) showQuickActionsUI() {
s.wQuickActions = s.app.NewWindow("NetBird")
vmCtx, vmCancel := context.WithCancel(s.ctx)
s.wQuickActions.SetOnClosed(vmCancel)
client, err := s.getSrvClient(defaultFailTimeout)
connCmd := connectCommand{
connectClient: func() error {
return s.menuUpClick(s.ctx)
},
}
disConnCmd := disconnectCommand{
disconnectClient: func() error {
return s.menuDownClick()
},
}
if err != nil {
log.Errorf("get service client: %v", err)
return
}
uiChan := make(chan quickActionsUiState, 1)
newQuickActionsViewModel(vmCtx, daemonClientConnectionStatusProvider{client: client}, connCmd, disConnCmd, uiChan)
connectedImage := s.getNetBirdImage("netbird.png", iconAbout)
disconnectedImage := s.getNetBirdImage("netbird-disconnected.png", iconAboutDisconnected)
connectedCircle := canvas.NewImageFromResource(resourceConnectedPng)
disconnectedCircle := canvas.NewImageFromResource(resourceDisconnectedPng)
connectedLabelText := "Disconnect"
disconnectedLabelText := "Connect"
toggleConnectionButton := widget.NewButtonWithIcon(disconnectedLabelText, disconnectedCircle.Resource, func() {
// This button's tap function will be set when an ui state arrives via the uiChan channel.
})
// Button starts disabled until the first ui state arrives.
toggleConnectionButton.Disable()
hintLabelText := fmt.Sprintf("You can always access NetBird from your %s.", getSystemTrayName())
hintLabel := widget.NewLabel(hintLabelText)
content := container.NewVBox(
layout.NewSpacer(),
disconnectedImage,
layout.NewSpacer(),
container.NewCenter(toggleConnectionButton),
layout.NewSpacer(),
container.NewCenter(hintLabel),
)
// this watches for ui state updates.
go func() {
for {
select {
case <-vmCtx.Done():
return
case uiState, ok := <-uiChan:
if !ok {
return
}
closed := s.applyQuickActionsUiState(
uiState,
quickActionsUiComponents{
content,
toggleConnectionButton,
connectedLabelText, disconnectedLabelText,
connectedImage, disconnectedImage,
connectedCircle.Resource, disconnectedCircle.Resource,
},
)
if closed {
return
}
}
}
}()
s.wQuickActions.SetContent(content)
s.wQuickActions.Resize(fyne.NewSize(400, 200))
s.wQuickActions.SetFixedSize(true)
s.wQuickActions.Show()
}

View File

@@ -0,0 +1,23 @@
// auto-generated
// Code generated by '$ fyne bundle'. DO NOT EDIT.
package main
import (
_ "embed"
"fyne.io/fyne/v2"
)
//go:embed assets/connected.png
var resourceConnectedPngData []byte
var resourceConnectedPng = &fyne.StaticResource{
StaticName: "assets/connected.png",
StaticContent: resourceConnectedPngData,
}
//go:embed assets/disconnected.png
var resourceDisconnectedPngData []byte
var resourceDisconnectedPng = &fyne.StaticResource{
StaticName: "assets/disconnected.png",
StaticContent: resourceDisconnectedPngData,
}

76
client/ui/signal_unix.go Normal file
View File

@@ -0,0 +1,76 @@
//go:build !windows && !(linux && 386)
package main
import (
"context"
"os"
"os/exec"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
)
// setupSignalHandler sets up a signal handler to listen for SIGUSR1.
// When received, it opens the quick actions window.
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-sigChan:
log.Info("received SIGUSR1 signal, opening quick actions window")
s.openQuickActions()
}
}
}()
}
// openQuickActions opens the quick actions window by spawning a new process.
func (s *serviceClient) openQuickActions() {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmd := exec.CommandContext(s.ctx, proc,
"--quick-actions=true",
"--daemon-addr="+s.addr,
)
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
if err := cmd.Start(); err != nil {
log.Errorf("start quick actions window: %v", err)
return
}
go func() {
if err := cmd.Wait(); err != nil {
log.Debugf("quick actions window exited: %v", err)
}
}()
}
// sendShowWindowSignal sends SIGUSR1 to the specified PID.
func sendShowWindowSignal(pid int32) error {
process, err := os.FindProcess(int(pid))
if err != nil {
return err
}
return process.Signal(syscall.SIGUSR1)
}

171
client/ui/signal_windows.go Normal file
View File

@@ -0,0 +1,171 @@
//go:build windows
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
quickActionsTriggerEventName = `Global\NetBirdQuickActionsTriggerEvent`
waitTimeout = 5 * time.Second
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
desiredAccesses = windows.SYNCHRONIZE | windows.EVENT_MODIFY_STATE
)
func getEventNameUint16Pointer() (*uint16, error) {
eventNamePtr, err := windows.UTF16PtrFromString(quickActionsTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", quickActionsTriggerEventName, err)
return nil, err
}
return eventNamePtr, nil
}
// setupSignalHandler sets up signal handling for Windows.
// Windows doesn't support SIGUSR1, so this uses a similar approach using windows.Events.
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
eventNamePtr, err := getEventNameUint16Pointer()
if err != nil {
return
}
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Quick actions trigger event '%s' already exists. Attempting to open.", quickActionsTriggerEventName)
eventHandle, err = windows.OpenEvent(desiredAccesses, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
return
}
log.Infof("Successfully opened existing quick actions trigger event '%s'.", quickActionsTriggerEventName)
} else {
log.Errorf("Failed to create quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for quick actions trigger event '%s'", quickActionsTriggerEventName)
return
}
log.Infof("Quick actions handler waiting for signal on event: %s", quickActionsTriggerEventName)
go s.waitForEvent(ctx, eventHandle)
}
func (s *serviceClient) waitForEvent(ctx context.Context, eventHandle windows.Handle) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close quick actions event handle '%s': %v", quickActionsTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on quick actions event. Opening quick actions window.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset quick actions event '%s': %v", quickActionsTriggerEventName, err)
}
s.openQuickActions()
case uint32(windows.WAIT_TIMEOUT):
default:
if isDone := logUnexpectedStatus(ctx, status, err); isDone {
return
}
}
}
}
func logUnexpectedStatus(ctx context.Context, status uint32, err error) bool {
log.Errorf("Unexpected status %d from WaitForSingleObject for quick actions event '%s': %v",
status, quickActionsTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
return false
case <-ctx.Done():
return true
}
}
// openQuickActions opens the quick actions window by spawning a new process.
func (s *serviceClient) openQuickActions() {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmd := exec.CommandContext(s.ctx, proc,
"--quick-actions=true",
"--daemon-addr="+s.addr,
)
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
if err := cmd.Start(); err != nil {
log.Errorf("error starting quick actions window: %v", err)
return
}
go func() {
if err := cmd.Wait(); err != nil {
log.Debugf("quick actions window exited: %v", err)
}
}()
}
func sendShowWindowSignal(pid int32) error {
_, err := os.FindProcess(int(pid))
if err != nil {
return err
}
eventNamePtr, err := getEventNameUint16Pointer()
if err != nil {
return err
}
eventHandle, err := windows.OpenEvent(desiredAccesses, false, eventNamePtr)
if err != nil {
return err
}
err = windows.SetEvent(eventHandle)
if err != nil {
return fmt.Errorf("Error setting event: %w", err)
}
return nil
}

57
go.mod
View File

@@ -16,7 +16,7 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.3.0
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.40.0
golang.org/x/sys v0.34.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
@@ -28,8 +28,8 @@ require (
)
require (
fyne.io/fyne/v2 v2.5.3
fyne.io/systray v1.11.0
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.14
@@ -43,7 +43,7 @@ require (
github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.7.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
@@ -56,12 +56,13 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
@@ -82,7 +83,7 @@ require (
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go v0.31.0
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
@@ -98,15 +99,17 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/mod v0.25.0
golang.org/x/mod v0.26.0
golang.org/x/net v0.42.0
golang.org/x/oauth2 v0.28.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.16.0
golang.org/x/term v0.33.0
golang.org/x/time v0.12.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
@@ -123,7 +126,7 @@ require (
dario.cat/mergo v1.0.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.4.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
@@ -146,7 +149,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.27 // indirect
github.com/containerd/containerd v1.7.29 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -157,11 +160,12 @@ require (
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.0 // indirect
github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
github.com/fredbi/uri v1.1.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect
github.com/fyne-io/glfw-js v0.3.0 // indirect
github.com/fyne-io/image v0.1.1 // indirect
github.com/fyne-io/oksvg v0.2.0 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -169,7 +173,7 @@ require (
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect
@@ -177,19 +181,19 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
@@ -208,7 +212,8 @@ require (
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
@@ -224,28 +229,26 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rymdport/portal v0.3.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.uber.org/mock v0.5.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/image v0.24.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.34.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect

548
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,31 @@
package cache
import (
"sync"
"github.com/netbirdio/netbird/shared/management/proto"
)
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
NameServerGroups sync.Map
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}

View File

@@ -0,0 +1,842 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util"
)
type Controller struct {
repo Repository
metrics *metrics
// This should not be here, but we need to maintain it for the time being
accountManagerMetrics *telemetry.AccountManagerMetrics
peersUpdateManager network_map.PeersUpdateManager
settingsManager settings.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
requestBuffer account.RequestBuffer
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
}
type bufferUpdate struct {
mu sync.Mutex
next *time.Timer
update atomic.Bool
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller {
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
if err != nil {
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
accountManagerMetrics: metrics.AccountManagerMetrics(),
peersUpdateManager: peersUpdateManager,
requestBuffer: requestBuffer,
integratedPeerValidator: integratedPeerValidator,
settingsManager: settingsManager,
dnsDomain: dnsDomain,
proxyController: proxyController,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
}
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
}
globalStart := time.Now()
hasPeersConnected := false
for _, peer := range account.Peers {
if c.peersUpdateManager.HasChannel(peer.ID) {
hasPeersConnected = true
break
}
}
if !hasPeersConnected {
return nil
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics, resourcePolicies, routers)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
}
peer := account.GetPeer(peerId)
if peer == nil {
return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validated peers: %v", err)
}
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics, resourcePolicies, routers)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
if err != nil {
return fmt.Errorf("failed to get extra settings: %v", err)
}
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
return nil
}
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountId)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountId)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerId)
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, 0, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics, resourcePolicies, routers)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
resourcePolicies map[string][]*types.Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
expMap := account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
go func() {
legacyMap := account.GetPeerNetworkMap(ctx, peerId, customZone, validatedPeers, resourcePolicies, routers, nil)
c.compareAndSaveNetworkMaps(ctx, accountId, peerId, expMap, legacyMap)
}()
return expMap
}
func (c *Controller) compareAndSaveNetworkMaps(ctx context.Context, accountId, peerId string, expMap, legacyMap *types.NetworkMap) {
expBytes, err := json.Marshal(expMap)
if err != nil {
log.WithContext(ctx).Warnf("failed to marshal experimental network map: %v", err)
return
}
legacyBytes, err := json.Marshal(legacyMap)
if err != nil {
log.WithContext(ctx).Warnf("failed to marshal legacy network map: %v", err)
return
}
// if len(expBytes) == len(legacyBytes) || math.Abs(float64(len(expBytes)-len(legacyBytes))) < 5 {
// log.WithContext(ctx).Debugf("network maps are equal for peer %s in account %s (size: %d bytes)", peerId, accountId, len(expBytes))
// return
// }
timestamp := time.Now().UnixMicro()
baseDir := filepath.Join("debug_networkmaps", accountId, peerId)
if err := os.MkdirAll(baseDir, 0o755); err != nil {
log.WithContext(ctx).Warnf("failed to create debug directory %s: %v", baseDir, err)
return
}
expFile := filepath.Join(baseDir, fmt.Sprintf("exp_networkmap_%d_%d.json", expMap.Network.Serial, timestamp))
if err := os.WriteFile(expFile, expBytes, 0o644); err != nil {
log.WithContext(ctx).Warnf("failed to write experimental network map to %s: %v", expFile, err)
return
}
legacyFile := filepath.Join(baseDir, fmt.Sprintf("legacy_networkmap_%d_%d.json", legacyMap.Network.Serial, timestamp))
if err := os.WriteFile(legacyFile, legacyBytes, 0o644); err != nil {
log.WithContext(ctx).Warnf("failed to write legacy network map to %s: %v", legacyFile, err)
return
}
// log.WithContext(ctx).Infof("network maps differ for peer %s in account %s - saved to %s (exp: %d bytes, legacy: %d bytes)", peerId, accountId, baseDir, len(expBytes), len(legacyBytes))
}
func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return c.dnsDomain
}
if settings.DNSDomain == "" {
return c.dnsDomain
}
return settings.DNSDomain
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
return maps.Values(peerPostureChecks), nil
}
func (c *Controller) StartWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
interval, err := strconv.Atoi(intervalStr)
if err != nil {
initialInterval = 1
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
} else {
initialInterval = int64(interval) * 10
go func() {
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
startupPeriod, err := strconv.Atoi(startupPeriodStr)
if err != nil {
startupPeriod = 1
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
}
time.Sleep(time.Duration(startupPeriod) * time.Second)
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
}()
}
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
}
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return int64(network_map.OldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
// Check if all peers have the required version or newer
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return int64(network_map.OldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return int64(network_map.OldForwarderPort)
}
}
// All peers have the required version or newer
return int64(network_map.DnsForwarderPort)
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false, nil
}
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
c.UpdatePeerInNetworkMapCache(accountId, peer)
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
}
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := c.repo.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil, resourcePolicies, routers)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return networkMap, nil
}
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) IsConnected(peerID string) bool {
return c.peersUpdateManager.HasChannel(peerID)
}

View File

@@ -0,0 +1,244 @@
package controller
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list
peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result)
}
// Test with peers that have old versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.26.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result)
}
// Test with peers that have new versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.DnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result)
}
// Test with peers that have mixed versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result)
}
// Test with peers that have empty version
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result)
}
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "development",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result == int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result)
}
// Test with peers that have unknown version string
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "unknown",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
}
}
func TestBufferUpdateAccountPeers(t *testing.T) {
const (
peersCount = 1000
updateAccountInterval = 50 * time.Millisecond
)
var (
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
uapLastRun, dpLastRun atomic.Int64
totalNewRuns, totalOldRuns int
)
uap := func(ctx context.Context, accountID string) {
updatePeersDeleted.Store(deletedPeers.Load())
updatePeersRuns.Add(1)
uapLastRun.Store(time.Now().UnixMilli())
time.Sleep(100 * time.Millisecond)
}
t.Run("new approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
b := mu.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
uap(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
b.next = time.AfterFunc(updateAccountInterval, func() {
uap(ctx, accountID)
})
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalNewRuns = int(updatePeersRuns.Load())
})
t.Run("old approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
b := mu.(*sync.Mutex)
if !b.TryLock() {
return
}
go func() {
time.Sleep(updateAccountInterval)
b.Unlock()
uap(ctx, accountID)
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalOldRuns = int(updatePeersRuns.Load())
})
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}

View File

@@ -0,0 +1,15 @@
package controller
import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
type metrics struct {
*telemetry.UpdateChannelMetrics
}
func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) {
return &metrics{
updateChannelMetrics,
}, nil
}

View File

@@ -0,0 +1,39 @@
package controller
import (
"context"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
type Repository interface {
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
}
type repository struct {
store store.Store
}
var _ Repository = (*repository)(nil)
func newRepository(s store.Store) Repository {
return &repository{
store: s,
}
}
func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) {
return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
}
func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) {
return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
return r.store.GetAccountByPeerID(ctx, peerID)
}

View File

@@ -0,0 +1,39 @@
package network_map
//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0"
)
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
DeletePeer(ctx context.Context, accountId string, peerId string) error
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
DisconnectPeers(ctx context.Context, peerIDs []string)
IsConnected(peerID string) bool
}

View File

@@ -0,0 +1,225 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
//
// Generated by this command:
//
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//
// Package network_map is a generated GoMock package.
package network_map
import (
context "context"
reflect "reflect"
peer "github.com/netbirdio/netbird/management/server/peer"
posture "github.com/netbirdio/netbird/management/server/posture"
types "github.com/netbirdio/netbird/management/server/types"
gomock "go.uber.org/mock/gomock"
)
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
isgomock struct{}
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// BufferUpdateAccountPeers mocks base method.
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
}
// DeletePeer mocks base method.
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
return ret0
}
// DeletePeer indicates an expected call of DeletePeer.
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
}
// DisconnectPeers mocks base method.
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
}
// DisconnectPeers indicates an expected call of DisconnectPeers.
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
}
// GetDNSDomain mocks base method.
func (m *MockController) GetDNSDomain(settings *types.Settings) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDNSDomain", settings)
ret0, _ := ret[0].(string)
return ret0
}
// GetDNSDomain indicates an expected call of GetDNSDomain.
func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings)
}
// GetNetworkMap mocks base method.
func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID)
ret0, _ := ret[0].(*types.NetworkMap)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetNetworkMap indicates an expected call of GetNetworkMap.
func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID)
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(int64)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// IsConnected mocks base method.
func (m *MockController) IsConnected(peerID string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsConnected", peerID)
ret0, _ := ret[0].(bool)
return ret0
}
// IsConnected indicates an expected call of IsConnected.
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
}
// OnPeerAdded mocks base method.
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerAdded indicates an expected call of OnPeerAdded.
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
}
// OnPeerDeleted mocks base method.
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
}
// OnPeerUpdated mocks base method.
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
}
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
}
// StartWarmup mocks base method.
func (m *MockController) StartWarmup(arg0 context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "StartWarmup", arg0)
}
// StartWarmup indicates an expected call of StartWarmup.
func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
}
// UpdateAccountPeer mocks base method.
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAccountPeer indicates an expected call of UpdateAccountPeer.
func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId)
}
// UpdateAccountPeers mocks base method.
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
}

View File

@@ -0,0 +1 @@
package network_map

View File

@@ -0,0 +1,13 @@
package network_map
import "context"
type PeersUpdateManager interface {
SendUpdate(ctx context.Context, peerID string, update *UpdateMessage)
CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage
CloseChannel(ctx context.Context, peerID string)
CountStreams() int
HasChannel(peerID string) bool
CloseChannels(ctx context.Context, peerIDs []string)
GetAllConnectedPeers() map[string]struct{}
}

View File

@@ -1,4 +1,4 @@
package server
package update_channel
import (
"context"
@@ -7,38 +7,34 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
NetworkMap *types.NetworkMap
}
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
peerChannels map[string]chan *network_map.UpdateMessage
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.RWMutex
// metrics provides method to collect application metrics
metrics telemetry.AppMetrics
}
var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil)
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
peerChannels: make(map[string]chan *network_map.UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
// SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) {
start := time.Now()
var found, dropped bool
@@ -66,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
}
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage {
start := time.Now()
closed := false
@@ -85,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
channel := make(chan *network_map.UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
@@ -176,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
return ok
}
func (p *PeersUpdateManager) CountStreams() int {
p.channelsMux.RLock()
defer p.channelsMux.RUnlock()
return len(p.peerChannels)
}

View File

@@ -1,10 +1,11 @@
package server
package update_channel
import (
"context"
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
update1 := &UpdateMessage{Update: &proto.SyncResponse{
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
@@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
update2 := &UpdateMessage{Update: &proto.SyncResponse{
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},

View File

@@ -0,0 +1,9 @@
package network_map
import (
"github.com/netbirdio/netbird/shared/management/proto"
)
type UpdateMessage struct {
Update *proto.SyncResponse
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
@@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}

View File

@@ -6,6 +6,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
@@ -14,9 +18,9 @@ import (
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
)
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
return Create(s, func() *server.PeersUpdateManager {
return server.NewPeersUpdateManager(s.Metrics())
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
return Create(s, func() *update_channel.PeersUpdateManager {
return update_channel.NewPeersUpdateManager(s.Metrics())
})
}
@@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
})
}
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
return Create(s, func() *server.TimeBasedAuthSecretsManager {
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
})
}
@@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
})
}
func (s *BaseServer) NetworkMapController() network_map.Controller {
return Create(s, func() *nmapcontroller.Controller {
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController())
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())
})
}

View File

@@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}

View File

@@ -0,0 +1,352 @@
package grpc
import (
"context"
"fmt"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
Uri: stun.URI,
Protocol: ToResponseProto(stun.Proto),
})
}
var turns []*proto.ProtectedHostConfig
if config.TURNConfig != nil {
for _, turn := range config.TURNConfig.Turns {
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Payload
password = turnCredentials.Signature
} else {
username = turn.Username
password = turn.Password
}
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: turn.URI,
Protocol: ToResponseProto(turn.Proto),
},
User: username,
Password: password,
})
}
}
var relayCfg *proto.RelayConfig
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
relayCfg = &proto.RelayConfig{
Urls: config.Relay.Addresses,
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
var signalCfg *proto.HostConfig
if config.Signal != nil {
signalCfg = &proto.HostConfig{
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
}
}
nbConfig := &proto.NetbirdConfig{
Stuns: stuns,
Turns: turns,
Signal: signalCfg,
Relay: relayCfg,
}
return nbConfig
}
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled,
}
}
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
},
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if networkMap.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
for _, rule := range networkMap.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
response.NetworkMap.ForwardingRules = forwardingRules
}
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
return proto.HostConfig_UDP
case nbconfig.DTLS:
return proto.HostConfig_DTLS
case nbconfig.HTTP:
return proto.HostConfig_HTTP
case nbconfig.HTTPS:
return proto.HostConfig_HTTPS
case nbconfig.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result[i] = fwRule
}
return result
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}

View File

@@ -0,0 +1,150 @@
package grpc
import (
"fmt"
"net/netip"
"reflect"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache cache.DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &cache.DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}

View File

@@ -1,4 +1,4 @@
package server
package grpc
import (
"hash/fnv"

View File

@@ -1,4 +1,4 @@
package server
package grpc
import (
"hash/fnv"

View File

@@ -1,4 +1,4 @@
package server
package grpc
import (
"context"
@@ -7,8 +7,10 @@ import (
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -20,7 +22,7 @@ import (
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
@@ -44,15 +46,18 @@ import (
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
defaultSyncLim = 1000
)
// GRPCServer an instance of a Management gRPC API server
type GRPCServer struct {
// Server an instance of a Management gRPC API server
type Server struct {
accountManager account.Manager
settingsManager settings.Manager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
peersUpdateManager network_map.PeersUpdateManager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
@@ -63,21 +68,28 @@ type GRPCServer struct {
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
loginFilter *loginFilter
networkMapController network_map.Controller
syncSem atomic.Int32
syncLim int32
}
// NewServer creates a new Management server
func NewServer(
ctx context.Context,
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager,
peersUpdateManager network_map.PeersUpdateManager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager ephemeral.Manager,
authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) {
networkMapController network_map.Controller,
) (*Server, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -86,7 +98,7 @@ func NewServer(
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
return int64(peersUpdateManager.CountStreams())
})
if err != nil {
return nil, err
@@ -96,7 +108,18 @@ func NewServer(
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
return &GRPCServer{
syncLim := int32(defaultSyncLim)
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
} else {
//nolint:gosec
syncLim = int32(syncLimParsed)
}
}
return &Server{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
@@ -110,10 +133,15 @@ func NewServer(
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
networkMapController: networkMapController,
loginFilter: newLoginFilter(),
syncLim: syncLim,
}, nil
}
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
ip := ""
p, ok := peer.FromContext(ctx)
if ok {
@@ -150,7 +178,12 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
reqStart := time.Now()
ctx := srv.Context()
@@ -158,13 +191,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
s.syncSem.Add(-1)
return err
}
realIP := getRealIP(ctx)
sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
@@ -172,6 +206,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
@@ -183,42 +218,54 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
s.syncSem.Add(-1)
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
s.syncSem.Add(-1)
return err
}
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
start := time.Now()
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return mapError(ctx, err)
}
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return err
}
@@ -235,13 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
unlock()
unlock = nil
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for {
select {
@@ -275,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
@@ -293,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
@@ -308,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.authManager == nil {
return "", status.Errorf(codes.Internal, "missing auth manager")
}
@@ -342,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return userAuth.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
@@ -450,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
}
}
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
@@ -469,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case it is, the login is successful
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now()
realIP := getRealIP(ctx)
sRealIP := realIP.String()
@@ -483,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
@@ -509,10 +556,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}()
if loginReq.GetMeta() == nil {
@@ -546,9 +599,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
@@ -557,6 +613,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
@@ -569,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
}, nil
}
func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -588,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings),
Checks: toProtocolChecks(ctx, postureChecks),
}
@@ -600,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
//
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := ""
if loginReq.GetJwtToken() != "" {
var err error
@@ -620,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil
}
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
return proto.HostConfig_UDP
case nbconfig.DTLS:
return proto.HostConfig_DTLS
case nbconfig.HTTP:
return proto.HostConfig_HTTP
case nbconfig.HTTPS:
return proto.HostConfig_HTTPS
case nbconfig.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
Uri: stun.URI,
Protocol: ToResponseProto(stun.Proto),
})
}
var turns []*proto.ProtectedHostConfig
if config.TURNConfig != nil {
for _, turn := range config.TURNConfig.Turns {
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Payload
password = turnCredentials.Signature
} else {
username = turn.Username
password = turn.Password
}
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: turn.URI,
Protocol: ToResponseProto(turn.Proto),
},
User: username,
Password: password,
})
}
}
var relayCfg *proto.RelayConfig
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
relayCfg = &proto.RelayConfig{
Urls: config.Relay.Addresses,
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
var signalCfg *proto.HostConfig
if config.Signal != nil {
signalCfg = &proto.HostConfig{
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
}
}
nbConfig := &proto.NetbirdConfig{
Stuns: stuns,
Turns: turns,
Signal: signalCfg,
Relay: relayCfg,
}
return nbConfig
}
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled,
}
}
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
},
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if networkMap.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
for _, rule := range networkMap.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
response.NetworkMap.ForwardingRules = forwardingRules
}
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// IsHealthy indicates whether the service is healthy
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
return &proto.Empty{}, nil
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
var err error
var turnToken *Token
@@ -803,29 +708,24 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request")
}
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
// Get all peers in the account for forwarder port computation
allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "")
if err != nil {
return fmt.Errorf("get account peers: %w", err)
}
dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion)
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
@@ -838,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
// GetDeviceAuthorizationFlow returns a device authorization flow information
// This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
@@ -896,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
@@ -951,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
@@ -976,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
return &proto.Empty{}, nil
}
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now()

View File

@@ -0,0 +1,106 @@
package grpc
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
wgKey: testingServerKey,
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}

View File

@@ -1,4 +1,4 @@
package server
package grpc
import (
"context"
@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
@@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct {
relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
updateManager network_map.PeersUpdateManager
settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
@@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
@@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {

View File

@@ -1,4 +1,4 @@
package server
package grpc
import (
"context"
@@ -13,6 +13,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
@@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
peersManager := update_channel.NewPeersUpdateManager(nil)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
@@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
@@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
}
var updates []*UpdateMessage
var updates []*network_map.UpdateMessage
loop:
for timeout := time.After(5 * time.Second); ; {
@@ -185,7 +187,7 @@ loop:
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &config.Relay{

View File

@@ -1,11 +1,19 @@
package main
import (
"github.com/netbirdio/netbird/management/cmd"
"log"
"net/http"
// nolint:gosec
_ "net/http/pprof"
"os"
"github.com/netbirdio/netbird/management/cmd"
)
func main() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}

View File

@@ -11,10 +11,8 @@ import (
"reflect"
"regexp"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
cacheStore "github.com/eko/gocache/lib/v4/store"
@@ -26,6 +24,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -68,7 +67,7 @@ type DefaultAccountManager struct {
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager
networkMapController network_map.Controller
idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache
@@ -88,8 +87,7 @@ type DefaultAccountManager struct {
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
peerLoginExpiry Scheduler
peerInactivityExpiry Scheduler
@@ -103,14 +101,11 @@ type DefaultAccountManager struct {
permissionsManager permissions.Manager
accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
loginFilter *loginFilter
disableDefaultPolicy bool
}
var _ account.Manager = (*DefaultAccountManager)(nil)
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
@@ -177,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
func BuildManager(
ctx context.Context,
store store.Store,
peersUpdateManager *PeersUpdateManager,
networkMapController network_map.Controller,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
@@ -199,12 +193,11 @@ func BuildManager(
am := &DefaultAccountManager{
Store: store,
geo: geo,
peersUpdateManager: peersUpdateManager,
networkMapController: networkMapController,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(),
@@ -215,11 +208,10 @@ func BuildManager(
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
}
am.startWarmup(ctx)
am.networkMapController.StartWarmup(ctx)
accountsCounter, err := store.GetAccountsCounter(ctx)
if err != nil {
@@ -267,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
am.ephemeralManager = em
}
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
interval, err := strconv.Atoi(intervalStr)
if err != nil {
initialInterval = 1
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
} else {
initialInterval = int64(interval) * 10
go func() {
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
startupPeriod, err := strconv.Atoi(startupPeriodStr)
if err != nil {
startupPeriod = 1
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
}
time.Sleep(time.Duration(startupPeriod) * time.Second)
am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
}()
}
am.updateAccountPeersBufferInterval.Store(initialInterval)
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
}
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
return am.externalCacheManager
}
@@ -1636,19 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
return am.loginFilter.allowLogin(wgPubKey, metahash)
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
}()
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
@@ -1656,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
metahash := metaHash(meta, realIP.String())
am.loginFilter.addLogin(peerPubKey, metahash)
return peer, netMap, postureChecks, nil
return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
@@ -1676,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return mapError(ctx, err)
return err
}
return nil
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil
}
// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
return am.peersUpdateManager.HasChannel(peerID)
}
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain)
}
// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return am.dnsDomain
}
if settings.DNSDomain == "" {
return am.dnsDomain
}
return settings.DNSDomain
}
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
peers := []*nbpeer.Peer{}
log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID)
@@ -2129,7 +2061,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
}
if updateNetworkMap {
am.BufferUpdateAccountPeers(ctx, accountID)
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
}
return nil
}
@@ -2177,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
dnsDomain := am.GetDNSDomain(settings)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String()

View File

@@ -89,7 +89,6 @@ type Manager interface {
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain(settings *types.Settings) string
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
@@ -97,10 +96,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
@@ -110,7 +107,7 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@@ -127,5 +124,4 @@ type Manager interface {
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
}

View File

@@ -0,0 +1,11 @@
package account
import (
"context"
"github.com/netbirdio/netbird/management/server/types"
)
type RequestBuffer interface {
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
}

View File

@@ -22,6 +22,9 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
@@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) {
}
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
@@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain, false)
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
@@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
}
func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
}
func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
}
func TestAccountManager_GetAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
func TestAccountManager_DeleteAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
DomainCategory: types.PublicCategory,
}
am, err := createManager(b)
am, _, err := createManager(b)
if err != nil {
b.Fatal(err)
return
@@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User {
}
func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1154,8 +1157,17 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
@@ -1181,8 +1193,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1205,11 +1217,20 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
// Ensure that we do not receive an update message before the policy is deleted
time.Sleep(time.Second)
@@ -1239,8 +1260,17 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
AccountID: account.Id,
@@ -1253,8 +1283,8 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
return
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1288,8 +1318,17 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
@@ -1318,8 +1357,11 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
// We need to sleep to wait for the buffer peer update
time.Sleep(300 * time.Millisecond)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1341,11 +1383,20 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -1377,6 +1428,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
for drained := false; !drained; {
select {
case <-updMsg:
default:
drained = true
}
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1404,7 +1463,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
}
func TestAccountManager_DeletePeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1485,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy
}
func TestGetUsersFromAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -1736,7 +1795,9 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)
@@ -1782,7 +1843,7 @@ func hasNilField(x interface{}) error {
}
func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1797,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1853,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1896,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1958,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -2622,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account
@@ -2864,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
// Fatalf(format string, args ...interface{})
// }
func createManager(t testing.TB) (*DefaultAccountManager, error) {
func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper()
store, err := createStore(t)
if err != nil {
return nil, err
return nil, nil, err
}
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
return nil, nil, err
}
ctrl := gomock.NewController(t)
@@ -2893,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, err
return nil, nil, err
}
return manager, nil
return manager, updateManager, nil
}
func createStore(t testing.TB) (store.Store, error) {
@@ -2927,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
}
}
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
t.Helper()
manager, err := createManager(t)
manager, updateManager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2971,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
peer2 := getPeer(manager, setupKey)
peer3 := getPeer(manager, setupKey)
return manager, account, peer1, peer2, peer3
return manager, updateManager, account, peer1, peer2, peer3
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
@@ -2984,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag
}
}
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3022,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3031,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
@@ -3085,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3094,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -3155,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3164,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -3227,7 +3289,7 @@ func TestMain(m *testing.M) {
}
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -3273,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -3303,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
}
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
t.Run("memory cache", func(t *testing.T) {
@@ -3353,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
}
func TestPropagateUserGroupMemberships(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
@@ -3470,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3502,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
}
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3541,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
}
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -3608,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
}
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -3654,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
}
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}

View File

@@ -3,54 +3,23 @@ package server
import (
"context"
"slices"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
dnsForwarderPort = nbdns.ForwarderServerPort
oldForwarderPort = nbdns.ForwarderClientPort
)
const dnsForwarderPortMinVersion = "v0.59.0"
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
NameServerGroups sync.Map
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
@@ -191,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return validateGroups(settings.DisabledManagementGroups, groups)
}
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return int64(oldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
// Check if all peers have the required version or newer
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return int64(oldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return int64(oldForwarderPort)
}
}
// All peers have the required version or newer
return int64(dnsForwarderPort)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}

View File

@@ -2,9 +2,7 @@ package server
import (
"context"
"fmt"
"net/netip"
"reflect"
"testing"
"time"
@@ -12,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock())
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createDNSStore(t *testing.T) (store.Store, error) {
@@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return am.Store.GetAccount(context.Background(), account.Id)
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list
peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
}
// Test with peers that have old versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.26.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
}
// Test with peers that have new versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(dnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
}
// Test with peers that have mixed versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
}
// Test with peers that have empty version
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
}
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "development",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result == int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
}
// Test with peers that have unknown version string
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "unknown",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
}
}
func TestDNSAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
@@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates

View File

@@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
}
func TestDefaultAccountManager_GetEvents(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
return
}

View File

@@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
@@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
@@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
if oldGroup.Name != newGroup.Name {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"old_name": oldGroup.Name,
"new_name": newGroup.Name,
}
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
})
}
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
eventsToStore = append(eventsToStore, func() {
@@ -354,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil
}
dnsDomain := am.GetDNSDomain(settings)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peerID := range addedPeers {
peer, ok := peers[peerID]

View File

@@ -37,7 +37,7 @@ const (
)
func TestDefaultAccountManager_CreateGroup(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
if err != nil {
t.Fatalf("failed to create account manager: %s", err)
}
@@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
}
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
@@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
}
func TestGroupAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving a group that is not linked to any resource should not update account peers
@@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) {
}
func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) {
}
func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP {
}
func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return

View File

@@ -4,11 +4,16 @@ import (
"context"
"fmt"
"net/http"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
@@ -38,7 +43,12 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const apiPrefix = "/api"
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
@@ -56,13 +66,45 @@ func NewAPIHandler(
permissionsManager permissions.Manager,
peersManager nbpeers.Manager,
settingsManager settings.Manager,
networkMapController network_map.Controller,
) (http.Handler, error) {
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
}
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
)
corsMiddleware := cors.AllowAll()
@@ -80,7 +122,7 @@ func NewAPIHandler(
}
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -23,11 +24,12 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
accountManager account.Manager
accountManager account.Manager
networkMapController network_map.Controller
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
peersHandler := NewHandler(accountManager)
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
peersHandler := NewHandler(accountManager, networkMapController)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager account.Manager) *Handler {
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
return &Handler{
accountManager: accountManager,
accountManager: accountManager,
networkMapController: networkMapController,
}
}
@@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if !h.accountManager.HasConnectedChannel(peer.ID) {
if !h.networkMapController.IsConnected(peer.ID) {
peerToReturn.Status.Connected = false
}
}
@@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
@@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
@@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
@@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
dnsDomain := h.accountManager.GetDNSDomain(account.Settings)
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)

View File

@@ -14,12 +14,14 @@ import (
"time"
"github.com/gorilla/mux"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -36,7 +38,7 @@ const (
serviceUser = "service_user"
)
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
@@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
},
}
ctrl := gomock.NewController(t)
networkMapController := network_map.NewMockController(ctrl)
networkMapController.EXPECT().
GetDNSDomain(gomock.Any()).
Return("domain").
AnyTimes()
networkMapController.EXPECT().
IsConnected(noUpdateChannelTestPeerID).
Return(false).
AnyTimes()
networkMapController.EXPECT().
IsConnected(gomock.Any()).
Return(true).
AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return account.Settings, nil
},
},
networkMapController: networkMapController,
}
}
@@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) {
rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1)
p := initTestMetaData(t, peer, peer1)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) {
UserID: regularUser,
}
p := initTestMetaData(peer1, peer2, peer3)
p := initTestMetaData(t, peer1, peer2, peer3)
tt := []struct {
name string
@@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
},
}
p := initTestMetaData(testPeer)
p := initTestMetaData(t, testPeer)
tt := []struct {
name string

View File

@@ -29,6 +29,7 @@ type AuthMiddleware struct {
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
}
// NewAuthMiddleware instance constructor
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
return &AuthMiddleware{
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
}
}
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
request, err := m.checkPATFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
err = status.Errorf(status.Unauthorized, "token invalid")
}
util.WriteError(r.Context(), err, w)
return
}
h.ServeHTTP(w, request)
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, fmt.Errorf("error extracting token: %w", err)
}
if m.rateLimiter != nil {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
}
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {

View File

@@ -27,7 +27,9 @@ const (
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
tokenID2 = "tokenID2"
PAT = "nbp_PAT"
PAT2 = "nbp_PAT2"
JWT = "JWT"
wrongToken = "wrongToken"
)
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
tokenID2: {
ID: tokenID2,
Name: "My second token",
HashedToken: "someHash2",
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
CreatedBy: userID,
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
},
},
},
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
if token == PAT {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
if token == PAT2 {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
}
func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID {
if token == tokenID || token == tokenID2 {
return nil
}
return fmt.Errorf("Should never get reached")
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
}
func TestAuthMiddleware_RateLimiting(t *testing.T) {
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
// Configure rate limiter: 10 requests per minute with burst of 5
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 5,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make burst requests - all should succeed
successCount := 0
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 5, successCount, "All burst requests should succeed")
// The 6th request should fail (exceeded burst)
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
})
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
// Configure very low rate limit: 1 request per minute
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request should fail (rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
// Configure strict rate limit
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make multiple requests with Bearer token - all should succeed
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
})
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
// Configure rate limiter
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Use first PAT token
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
// Second request with same token should fail
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
// Use second PAT token - should succeed because it has independent rate limit
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
// Second request with PAT2 should also be rate limited
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
// JWT should still work (not rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
})
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
// Configure rate limiter with short cleanup interval and TTL for testing
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: 100 * time.Millisecond,
LimiterTTL: 200 * time.Millisecond,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request - should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request immediately - should fail (burst exhausted)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
time.Sleep(400 * time.Millisecond)
// After cleanup, the limiter should be removed and recreated with full burst capacity
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
// Verify it's a fresh limiter by checking burst is reset
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
for _, tc := range tt {

View File

@@ -0,0 +1,146 @@
package middleware
import (
"context"
"sync"
"time"
"golang.org/x/time/rate"
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
RequestsPerMinute float64
// Burst defines the maximum number of requests that can be made in a burst
Burst int
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
CleanupInterval time.Duration
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
LimiterTTL time.Duration
}
// DefaultRateLimiterConfig returns a default configuration
func DefaultRateLimiterConfig() *RateLimiterConfig {
return &RateLimiterConfig{
RequestsPerMinute: 100,
Burst: 120,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// APIRateLimiter manages rate limiting for API tokens
type APIRateLimiter struct {
config *RateLimiterConfig
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
if config == nil {
config = DefaultRateLimiterConfig()
}
rl := &APIRateLimiter{
config: config,
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
go rl.cleanupLoop()
return rl
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
limiter := rl.getLimiter(key)
return limiter.Allow()
}
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
// getLimiter retrieves or creates a rate limiter for the given key
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.RLock()
entry, exists := rl.limiters[key]
rl.mu.RUnlock()
if exists {
rl.mu.Lock()
entry.lastAccess = time.Now()
rl.mu.Unlock()
return entry.limiter
}
rl.mu.Lock()
defer rl.mu.Unlock()
if entry, exists := rl.limiters[key]; exists {
entry.lastAccess = time.Now()
return entry.limiter
}
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
rl.limiters[key] = &limiterEntry{
limiter: limiter,
lastAccess: time.Now(),
}
return limiter
}
// cleanupLoop periodically removes old limiters that haven't been used recently
func (rl *APIRateLimiter) cleanupLoop() {
ticker := time.NewTicker(rl.config.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
rl.cleanup()
case <-rl.stopChan:
return
}
}
}
// cleanup removes limiters that haven't been used within the TTL period
func (rl *APIRateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for key, entry := range rl.limiters {
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
delete(rl.limiters, key)
}
}
}
// Stop stops the cleanup goroutine
func (rl *APIRateLimiter) Stop() {
close(rl.stopChan)
}
// Reset removes the rate limiter for a specific key
func (rl *APIRateLimiter) Reset(key string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.limiters, key)
}

View File

@@ -10,6 +10,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
@@ -31,7 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/users"
)
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
@@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
peersUpdateManager := update_channel.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
done := make(chan struct{})
if validateUpdate {
@@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock())
am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
@@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
@@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server
}
}
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) {
t.Helper()
select {

View File

@@ -22,10 +22,14 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
return loginResp, nil
}
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}
func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
@@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
t.Fatal(err)
}
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
@@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := BuildManager(ctx, store, networkMapController, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
@@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, nil, "", cleanup, err
}
@@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) {
peerLogin := types.PeerLogin{
WireGuardPubKey: key.String(),
SSHKey: "random",
Meta: extractPeerMeta(context.Background(), meta),
SetupKey: setupKey.Key,
ConnectionIP: net.IP{1, 1, 1, 1},
Meta: nbpeer.PeerSystemMeta{
Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(),
Kernel: meta.GetKernel(),
Platform: meta.GetPlatform(),
OS: meta.GetOS(),
OSVersion: meta.GetOSVersion(),
WtVersion: meta.GetNetbirdVersion(),
UIVersion: meta.GetUiVersion(),
KernelVersion: meta.GetKernelVersion(),
SystemSerialNumber: meta.GetSysSerialNumber(),
SystemProductName: meta.GetSysProductName(),
SystemManufacturer: meta.GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: meta.GetEnvironment().GetCloud(),
Platform: meta.GetEnvironment().GetPlatform(),
},
Flags: nbpeer.Flags{
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
DisableDNS: meta.GetFlags().GetDisableDNS(),
DisableFirewall: meta.GetFlags().GetDisableFirewall(),
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
BlockInbound: meta.GetFlags().GetBlockInbound(),
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
},
},
SetupKey: setupKey.Key,
ConnectionIP: net.IP{1, 1, 1, 1},
}
login := func() error {

View File

@@ -20,7 +20,10 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -176,7 +179,6 @@ func startServer(
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -199,13 +201,18 @@ func startServer(
AnyTimes()
permissionsManager := permissions.NewManager(str)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(
context.Background(),
str,
peersUpdateManager,
networkMapController,
nil,
"",
"netbird.selfhosted",
eventStore,
nil,
false,
@@ -220,18 +227,18 @@ func startServer(
}
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(
context.Background(),
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(
config,
accountManager,
settingsMockManager,
peersUpdateManager,
updateManager,
secretsManager,
nil,
&manager.EphemeralManager{},
nil,
server.MockIntegratedValidator{},
networkMapController,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)

View File

@@ -38,7 +38,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -94,7 +94,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
@@ -125,9 +125,10 @@ type MockAccountManager struct {
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -177,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
@@ -746,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, accountID)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
}
return true
}
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
if am.RecalculateNetworkMapCacheFunc != nil {
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
}
return nil
}

View File

@@ -11,6 +11,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes()
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createNSStore(t *testing.T) (store.Store, error) {
@@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) {
}
func TestNameServerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup
@@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a nameserver group with a distribution group no peers should not update account peers

View File

@@ -8,8 +8,6 @@ import (
"net"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
@@ -23,7 +21,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -31,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -106,11 +102,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start))
}()
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
@@ -145,9 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.OnPeerUpdated(accountID, peer)
}
return nil
@@ -203,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
var peer *nbpeer.Peer
var settings *types.Settings
var peerGroupList []string
var requiresPeerUpdates bool
var peerLabelChanged bool
var sshChanged bool
var loginExpirationChanged bool
@@ -226,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err
}
dnsDomain = am.GetDNSDomain(settings)
dnsDomain = am.networkMapController.GetDNSDomain(settings)
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
if err != nil {
return err
}
@@ -321,11 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
if peerLabelChanged || requiresPeerUpdates {
am.UpdateAccountPeers(ctx, accountID)
} else if sshChanged {
am.UpdateAccountPeer(ctx, accountID, peer.ID)
}
am.networkMapController.OnPeerUpdated(accountID, peer)
return peer, nil
}
@@ -381,8 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
}
return nil
@@ -390,41 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return networkMap, nil
return am.networkMapController.GetNetworkMap(ctx, peerID)
}
// GetPeerNetwork returns the Network for a given peer
@@ -683,16 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
am.BufferUpdateAccountPeers(ctx, accountID)
if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
return p, nmap, pc, err
}
func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
@@ -707,12 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start))
}()
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
@@ -722,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, 0, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -772,14 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.OnPeerUpdated(accountID, peer)
}
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -831,6 +784,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
@@ -900,11 +854,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.OnPeerUpdated(accountID, peer)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
return p, nmap, pc, err
}
// getPeerPostureChecks returns the posture checks for the peer.
@@ -996,57 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, nil
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return peer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
@@ -1070,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
return fmt.Errorf("failed to get account settings: %w", err)
}
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings)))
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
return nil
}
@@ -1166,209 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
globalStart := time.Now()
hasPeersConnected := false
for _, peer := range account.Peers {
if am.peersUpdateManager.HasChannel(peer.ID) {
hasPeersConnected = true
break
}
}
if !hasPeersConnected {
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err)
return
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
}
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
return
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
for _, peer := range account.Peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return
}
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
//
wg.Wait()
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
}
}
type bufferUpdate struct {
mu sync.Mutex
next *time.Timer
update atomic.Bool
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
am.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
am.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
}()
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
}
// UpdateAccountPeer updates a single peer that belongs to an account.
// Should be called when changes need to be synced to a specific peer only.
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
if !am.peersUpdateManager.HasChannel(peerId) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
return
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
return
}
peer := account.GetPeer(peerId)
if peer == nil {
log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId)
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
return
}
dnsCache := &DNSConfigCache{}
dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
postureChecks, err := am.getPeerPostureChecks(account, peerId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
return
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
return
}
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
_ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId)
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1523,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err != nil {
return nil, err
}
dnsDomain := am.GetDNSDomain(settings)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peer := range peers {
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
@@ -1564,25 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err
}
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
NetworkMap: &types.NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
@@ -1591,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return peerDeletedEvents, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}
// validatePeerDelete checks if the peer can be deleted.
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error {
linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId)

View File

@@ -13,7 +13,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -25,10 +24,14 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
@@ -168,7 +171,16 @@ func TestPeer_SessionExpired(t *testing.T) {
}
func TestAccountManager_GetNetworkMap(t *testing.T) {
manager, err := createManager(t)
testGetNetworkMapGeneral(t)
}
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) {
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -240,7 +252,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
// TODO: disable until we start use policy again
t.Skip()
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -417,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
func TestAccountManager_GetPeerNetwork(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -478,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
}
func TestDefaultAccountManager_GetPeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -665,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -733,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
manager, updateManager, err := createManager(b)
if err != nil {
return nil, "", "", err
return nil, nil, "", "", err
}
accountID := "test_account"
@@ -789,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
ips := account.GetTakenIPs()
peerIP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, "", "", err
return nil, nil, "", "", err
}
peerKey, _ := wgtypes.GeneratePrivateKey()
@@ -895,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, "", "", err
return nil, nil, "", "", err
}
return manager, accountID, regularUser, nil
return manager, updateManager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
@@ -919,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -959,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -971,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -1003,7 +1011,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t)
}
func testUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
@@ -1019,7 +1036,7 @@ func TestUpdateAccountPeers(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -1031,20 +1048,19 @@ func TestUpdateAccountPeers(t *testing.T) {
t.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
peerChannels := make(map[string]chan *network_map.UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}
})
}
@@ -1079,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) {
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
turnRelayToken := &Token{
turnRelayToken := &grpc.Token{
Payload: "turn-user",
Signature: "turn-pass",
}
@@ -1159,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) {
},
},
}
dnsCache := &DNSConfigCache{}
dnsCache := &cache.DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
assert.NotNil(t, response)
// assert peer config
@@ -1271,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1351,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1499,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1548,6 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1573,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1706,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) {
}
func TestPeerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err)
@@ -1763,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
var peer5 *nbpeer.Peer
var peer6 *nbpeer.Peer
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
t.Skip("Currently all updates will trigger a network map")
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
@@ -1871,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
})
t.Run("validator requires no update", func(t *testing.T) {
t.Skip("Currently all updates will trigger a network map")
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
@@ -2072,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}
func Test_DeletePeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -2169,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
}
func Test_AddPeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -2257,136 +2297,8 @@ func Test_AddPeer(t *testing.T) {
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
}
func TestBufferUpdateAccountPeers(t *testing.T) {
const (
peersCount = 1000
updateAccountInterval = 50 * time.Millisecond
)
var (
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
uapLastRun, dpLastRun atomic.Int64
totalNewRuns, totalOldRuns int
)
uap := func(ctx context.Context, accountID string) {
updatePeersDeleted.Store(deletedPeers.Load())
updatePeersRuns.Add(1)
uapLastRun.Store(time.Now().UnixMilli())
time.Sleep(100 * time.Millisecond)
}
t.Run("new approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
b := mu.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
uap(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
b.next = time.AfterFunc(updateAccountInterval, func() {
uap(ctx, accountID)
})
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalNewRuns = int(updatePeersRuns.Load())
})
t.Run("old approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
b := mu.(*sync.Mutex)
if !b.TryLock() {
return
}
go func() {
time.Sleep(updateAccountInterval)
b.Unlock()
uap(ctx, accountID)
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalOldRuns = int(updatePeersRuns.Load())
})
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2423,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
}
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2457,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
}
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2522,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
}
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture"
@@ -252,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result[i] = fwRule
}
return result
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}

View File

@@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
@@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
PolicyID: "RuleDefault",
},
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.250.202",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.250.202",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
@@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.21.56",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
@@ -991,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
}
func TestPolicyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -1020,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
var policyWithGroupRulesNoPeers *types.Policy

View File

@@ -2,19 +2,15 @@ package server
import (
"context"
"errors"
"fmt"
"slices"
"github.com/rs/xid"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -136,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
return maps.Values(peerPostureChecks), nil
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
@@ -211,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
return nil
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false, nil
}
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)

Some files were not shown because too many files have changed in this diff Show More