Compare commits

...

14 Commits

Author SHA1 Message Date
bcmmbaga
4bed26e416 Add diff and hash tests for ignored tags 2024-07-02 13:29:46 +03:00
bcmmbaga
67cc8bd655 add tests 2024-07-02 12:57:42 +03:00
bcmmbaga
42be72a86c Replace hashstructure package with r3labs/diff for network map updates 2024-06-28 15:44:38 +03:00
bcmmbaga
16387a823a Reset timer in benchmark test functions 2024-06-27 17:14:22 +03:00
bcmmbaga
b4dddc8d0f Add server account peer update functions and tests 2024-06-27 00:55:48 +03:00
bcmmbaga
7a0dc10ccc Add network map hash to avoid unnecessary updates 2024-06-26 16:29:10 +03:00
Viktor Liu
628673db20 Lower retry interval on dns resolve failure (#2176) 2024-06-24 11:55:07 +02:00
Bethuel Mmbaga
eaa31c2dc6 Optimize process checks database read (#2182)
* Add posture checks to peer management

This commit includes posture checks to the peer management logic. The AddPeer, SyncPeer and LoginPeer functions now return a list of posture checks along with the peer and network map.

* Update peer methods to return posture checks

* Refactor

* return early if there is no posture checks

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-06-22 17:41:16 +03:00
Zoltan Papp
25723e9b07 Do not use eBPF proxy in case of USP mode (#2180) 2024-06-22 15:33:10 +02:00
Robert Neumann
3cf4d5758f Update Zitadel and CockroachDB Container Image Version (#2169)
* fix type in docker compose

* Update docker compose cockroachdb to latest-23.2 and zitadel to 2.54.3
2024-06-22 12:44:45 +02:00
Bethuel Mmbaga
fc15ee6351 auto migrate older management to sqlite (#2170) 2024-06-20 19:45:57 +02:00
Viktor Liu
4a3e78fb0f Fix windows network monitor next hop ip log (#2168) 2024-06-20 16:59:33 +02:00
Viktor Liu
f9462eea27 Fix dns route retrieval condition (#2165)
* Fix route retrieval condition

* Make error messages take domains into account
2024-06-20 13:52:32 +02:00
Viktor Liu
b075009ef7 Fix windows route zones (#2164)
* Fix windows zone and add additional debug output

* Fix routes zone on BSD

* Remove redundant Unmap

* Add zone to windows routes
2024-06-20 13:02:02 +02:00
27 changed files with 835 additions and 233 deletions

View File

@@ -282,8 +282,6 @@ func (e *Engine) Start() error {
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
wgIface, err := e.newWgIface()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@@ -291,6 +289,9 @@ func (e *Engine) Start() error {
}
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {

View File

@@ -101,11 +101,14 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[net
if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
intf := "<nil>"
if r.Interface != nil {
intf = r.Interface.Name
oldIntf, newIntf := "<nil>", "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
if r.Interface != nil {
newIntf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s from %s (%s) to %s (%s)", r.Destination, nexthop.IP, oldIntf, r.Nexthop, newIntf)
return true
}
} else {

View File

@@ -36,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -51,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -88,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -124,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@@ -154,7 +154,7 @@ func TestConn_Status(t *testing.T) {
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()

View File

@@ -23,7 +23,8 @@ import (
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
@@ -160,7 +161,12 @@ func (r *Route) startResolver(ctx context.Context) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
r.update(ctx)
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for {
select {
@@ -168,17 +174,28 @@ func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
r.update(ctx)
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
}
}
}
func (r *Route) update(ctx context.Context) {
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
log.Errorf("Failed to update dynamic routes for [%v]: %v", r, err)
return fmt.Errorf("update dynamic routes: %w", err)
}
return nil
}
func (r *Route) resolveDomains() (domainMap, error) {

View File

@@ -92,7 +92,7 @@ func toNetIP(a route.Addr) netip.Addr {
case *route.Inet6Addr:
ip := netip.AddrFrom16(t.IP)
if t.ZoneID != 0 {
ip.WithZone(strconv.Itoa(t.ZoneID))
ip = ip.WithZone(strconv.Itoa(t.ZoneID))
}
return ip
default:

View File

@@ -356,7 +356,7 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
}
return Nexthop{
IP: addr.Unmap(),
IP: addr,
Intf: intf,
}, nil
}
@@ -380,12 +380,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
zone := intf.Name
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
zone = strconv.Itoa(intf.Index)
}
log.Tracef("Adding zone %s to address %s", zone, addr)
addr = addr.WithZone(zone)
}
return addr.Unmap(), nil

View File

@@ -71,7 +71,6 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return fmt.Errorf("invalid zone: %w", err)
}
nexthop.Intf = &net.Interface{Index: zone}
nexthop.IP.WithZone("")
}
return addRouteCmd(prefix, nexthop)
@@ -80,8 +79,8 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
nexthop.IP.WithZone("")
args = append(args, nexthop.IP.Unmap().String())
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
@@ -146,6 +145,10 @@ func GetRoutes() ([]Route, error) {
Index: int(entry.InterfaceIndex),
Name: entry.InterfaceAlias,
}
if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
}
}
routes = append(routes, Route{
@@ -189,7 +192,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {

View File

@@ -8,9 +8,13 @@ import (
log "github.com/sirupsen/logrus"
)
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
if err != nil {

View File

@@ -4,6 +4,6 @@ package wgproxy
import "context"
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}

1
go.mod
View File

@@ -67,6 +67,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/r3labs/diff v1.1.0
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966

2
go.sum
View File

@@ -415,6 +415,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/r3labs/diff v1.1.0 h1:V53xhrbTHrWFWq3gI4b94AjgEJOerO1+1l0xyHOBi8M=
github.com/r3labs/diff v1.1.0/go.mod h1:7WjXasNzi0vJetRcB/RqNl5dlIsmXcTTLmF5IoH6Xig=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=

View File

@@ -717,7 +717,7 @@ services:
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile
#UI dashboard
# UI dashboard
dashboard:
image: netbirdio/dashboard:latest
restart: unless-stopped
@@ -760,7 +760,7 @@ services:
zitadel:
restart: 'always'
networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.31.3'
image: 'ghcr.io/zitadel/zitadel:v2.54.3'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file:
- ./zitadel.env
@@ -774,7 +774,7 @@ services:
crdb:
restart: 'always'
networks: [netbird]
image: 'cockroachdb/cockroach:v22.2.2'
image: 'cockroachdb/cockroach:latest-v23.2'
command: 'start-single-node --advertise-addr crdb'
volumes:
- netbird_crdb_data:/cockroach/cockroach-data

View File

@@ -18,6 +18,11 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
@@ -33,10 +38,6 @@ import (
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
const (
@@ -83,7 +84,7 @@ type AccountManager interface {
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error)
AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
@@ -118,10 +119,9 @@ type AccountManager interface {
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
@@ -133,7 +133,7 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
GroupValidation(accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@@ -167,6 +167,8 @@ type DefaultAccountManager struct {
userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator
networkMapHash map[string]uint64
}
// Settings represents Account settings structure that can be modified via API and Dashboard
@@ -383,9 +385,9 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.IsDynamic() && r.Domains.PunycodeString() == domains.PunycodeString() {
routes = append(routes, r)
} else if r.Network.String() == prefix.String() {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
@@ -1855,13 +1857,13 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
}
return nil, nil, err
return nil, nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(accountID)
@@ -1869,12 +1871,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
peer, netMap, postureChecks, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
@@ -1882,7 +1884,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, nil
return peer, netMap, postureChecks, nil
}
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
@@ -1925,7 +1927,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.Pee
return err
}
_, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
_, _, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
if err != nil {
return mapError(err)
}

View File

@@ -85,7 +85,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
setupKey = key.Key
}
_, _, err := manager.AddPeer(setupKey, userID, peer)
_, _, _, err := manager.AddPeer(setupKey, userID, peer)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -997,7 +997,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1065,7 +1065,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1140,7 +1140,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@@ -1315,7 +1315,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
})
@@ -1662,7 +1662,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1715,7 +1715,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1759,7 +1759,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,

View File

@@ -256,11 +256,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return nil, err
}
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
savedPeer1, _, _, err := am.AddPeer("", dnsAdminUserID, peer1)
if err != nil {
return nil, err
}
_, _, err = am.AddPeer("", dnsAdminUserID, peer2)
_, _, _, err = am.AddPeer("", dnsAdminUserID, peer2)
if err != nil {
return nil, err
}

View File

@@ -139,12 +139,12 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
if err != nil {
return mapError(err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
err = s.sendInitialSync(peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
@@ -376,7 +376,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq.GetMeta()),
@@ -398,7 +398,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(s.accountManager, peerKey.String()),
Checks: toProtocolChecks(postureChecks),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
@@ -520,7 +520,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -551,7 +551,7 @@ func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toProtocolChecks(accountManager, peer.Key),
Checks: toProtocolChecks(checks),
}
}
@@ -561,7 +561,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
@@ -570,7 +570,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@@ -715,15 +715,9 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
return &proto.Empty{}, nil
}
// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side.
func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
if err != nil {
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
return nil
}
protoChecks := make([]*proto.Checks, 0)
// toProtocolChecks converts posture checks to protocol checks.
func toProtocolChecks(postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
}
@@ -732,7 +726,7 @@ func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Ch
}
// toProtocolCheck converts a posture.Checks to a proto.Checks.
func toProtocolCheck(postureCheck posture.Checks) *proto.Checks {
func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
protoCheck := &proto.Checks{}
if check := postureCheck.Checks.ProcessCheck; check != nil {

121
management/server/hash.go Normal file
View File

@@ -0,0 +1,121 @@
package server
import (
"github.com/mitchellh/hashstructure/v2"
"github.com/r3labs/diff"
log "github.com/sirupsen/logrus"
)
func updateAccountPeers(account *Account) {
//start := time.Now()
//defer func() {
// duration := time.Since(start)
// log.Printf("Finished execution of updateAccountPeers, took %v\n", duration)
//}()
peers := account.GetPeers()
approvedPeersMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
approvedPeersMap[peer.ID] = struct{}{}
}
for _, peer := range peers {
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
_ = account.GetPeerNetworkMap(peer.ID, "netbird.io", approvedPeersMap)
//remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
//postureChecks := am.getPeerPostureChecks(account, peer)
//update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
//am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}
func updateAccountPeersWithHash(account *Account) {
//start := time.Now()
//var skipUpdate int
//defer func() {
// duration := time.Since(start)
// log.Printf("Finished execution of updateAccountPeers, took %v\n", duration.Nanoseconds())
// log.Println("not updated peers: ", skipUpdate)
//}()
peers := account.GetPeers()
approvedPeersMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
approvedPeersMap[peer.ID] = struct{}{}
}
for _, peer := range peers {
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
//33006042459
// 8700718125
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, "netbird.io", approvedPeersMap)
//log.Println("firewall rules: ", len(remotePeerNetworkMap.FirewallRules))
hashStr, err := hashstructure.Hash(remotePeerNetworkMap, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
//Hasher: xxhash.New(),
})
if err != nil {
log.Errorf("failed to generate network map hash: %v", err)
} else {
if peer.NetworkMapHash == hashStr {
//log.Debugf("not sending network map update to peer: %s as there is nothing new", peer.ID)
//skipUpdate++
continue
}
peer.NetworkMapHash = hashStr
}
}
}
func updateAccountPeersWithDiff(account *Account) {
//start := time.Now()
//var skipUpdate int
//defer func() {
// duration := time.Since(start)
// log.Printf("Finished execution of updateAccountPeers, took %v\n", duration.Nanoseconds())
// log.Println("not updated peers: ", skipUpdate)
//}()
peers := account.GetPeers()
approvedPeersMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
approvedPeersMap[peer.ID] = struct{}{}
}
for _, peer := range peers {
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
//33006042459
// 8700718125
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, "netbird.io", approvedPeersMap)
peer.NetworkMap = remotePeerNetworkMap
changelog, err := diff.Diff(peer.NetworkMap, remotePeerNetworkMap)
if err != nil {
log.Errorf("failed to generate network map diff: %v", err)
} else {
if len(changelog) == 0 {
continue
}
}
}
}
//48868101197
// 8700718125

View File

@@ -0,0 +1,424 @@
package server
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/mitchellh/hashstructure/v2"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
route2 "github.com/netbirdio/netbird/route"
"github.com/r3labs/diff"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func initTestAccount(b *testing.B, numPerAccount int) *Account {
b.Helper()
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
b.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
Description: "base route",
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
group := &nbgroup.Group{
ID: "randomID",
AccountID: account.Id,
Name: "randomName",
Issued: "api",
Peers: groupALL.Peers[:numPerAccount-1],
}
account.Groups[group.ID] = group
account.Policies = []*Policy{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Enabled: true,
Rules: []*PolicyRule{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
Sources: []string{
group.ID,
},
Destinations: []string{
group.ID,
},
},
{
ID: "RuleDefault2",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolUDP,
Action: PolicyTrafficActionAccept,
Sources: []string{
groupALL.ID,
},
Destinations: []string{
groupALL.ID,
},
},
},
},
}
return account
}
// 1000 - 6717416375 ns/op
// 500 - 1732888875 ns/op
func BenchmarkTest_updateAccountPeers100(b *testing.B) {
account := initTestAccount(b, 100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash100(b *testing.B) {
account := initTestAccount(b, 100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff100(b *testing.B) {
account := initTestAccount(b, 100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
// 1000 - 6717416375 ns/op
// 500 - 1732888875 ns/op
func BenchmarkTest_updateAccountPeers200(b *testing.B) {
account := initTestAccount(b, 200)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash200(b *testing.B) {
account := initTestAccount(b, 200)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff200(b *testing.B) {
account := initTestAccount(b, 200)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers500(b *testing.B) {
account := initTestAccount(b, 500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash500(b *testing.B) {
account := initTestAccount(b, 500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff500(b *testing.B) {
account := initTestAccount(b, 500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers1000(b *testing.B) {
account := initTestAccount(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash1000(b *testing.B) {
account := initTestAccount(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff1000(b *testing.B) {
account := initTestAccount(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers1500(b *testing.B) {
account := initTestAccount(b, 1500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash1500(b *testing.B) {
account := initTestAccount(b, 1500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff1500(b *testing.B) {
account := initTestAccount(b, 1500)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
func BenchmarkTest_updateAccountPeers2000(b *testing.B) {
account := initTestAccount(b, 2000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
updateAccountPeers(account)
}
}
// 1000 - 28943404000 ns/op
// 500 - 7365024500 ns/op
func BenchmarkTest_updateAccountPeersWithHash2000(b *testing.B) {
account := initTestAccount(b, 2000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithHash(account)
}
}
func BenchmarkTest_updateAccountPeersWithDiff2000(b *testing.B) {
account := initTestAccount(b, 2000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
log.Debug(i)
updateAccountPeersWithDiff(account)
}
}
type TestStruct struct {
Name string
Value int
Ignored string `diff:"-" hash:"ignore"`
Compared string
}
func TestDiffIgnoreTag(t *testing.T) {
a := TestStruct{
Name: "test",
Value: 30,
Ignored: "This should be ignored",
Compared: "This should be compared",
}
b := TestStruct{
Name: "test",
Value: 31,
Ignored: "This is different but should be ignored",
Compared: "This is different and should be compared",
}
changelog, err := diff.Diff(a, b)
assert.NoError(t, err)
// Check that only the expected fields are in the changelog
assert.Len(t, changelog, 2)
// Check that the 'Age' field change is detected
ageChange := getChangeForField(changelog, "Value")
assert.NotNil(t, ageChange)
assert.Equal(t, 30, ageChange.From)
assert.Equal(t, 31, ageChange.To)
// Check that the 'Compared' field change is detected
comparedChange := getChangeForField(changelog, "Compared")
assert.NotNil(t, comparedChange)
assert.Equal(t, "This should be compared", comparedChange.From)
assert.Equal(t, "This is different and should be compared", comparedChange.To)
// Check that the 'Ignored' field is not in the changelog
ignoredChange := getChangeForField(changelog, "Ignored")
assert.Nil(t, ignoredChange)
}
func TestHashIgnoreTag(t *testing.T) {
a := TestStruct{
Name: "test",
Value: 30,
Ignored: "This should be ignored",
Compared: "This should be compared",
}
b := TestStruct{
Name: "test",
Value: 30,
Ignored: "This is different but should be ignored",
Compared: "This should be compared",
}
c := TestStruct{
Name: "test",
Value: 31,
Ignored: "This should be ignored",
Compared: "This should be compared",
}
d := TestStruct{
Name: "test",
Value: 30,
Ignored: "This should be ignored",
Compared: "This is different and should be compared",
}
opts := &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
}
hashA, err := hashstructure.Hash(a, hashstructure.FormatV2, opts)
assert.NoError(t, err)
hashB, err := hashstructure.Hash(b, hashstructure.FormatV2, opts)
assert.NoError(t, err)
hashC, err := hashstructure.Hash(c, hashstructure.FormatV2, opts)
assert.NoError(t, err)
hashD, err := hashstructure.Hash(d, hashstructure.FormatV2, opts)
assert.NoError(t, err)
// Test that changing the ignored field does not change the hash
assert.Equal(t, hashA, hashB)
// Test that changing a non-ignored field does change the hash
assert.NotEqual(t, hashA, hashC)
assert.NotEqual(t, hashA, hashD)
}
func getChangeForField(changelog diff.Changelog, fieldName string) *diff.Change {
for _, change := range changelog {
if change.Path[0] == fieldName {
return &change
}
}
return nil
}

View File

@@ -30,11 +30,11 @@ type MockAccountManager struct {
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
@@ -83,10 +83,9 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecksFunc func(peerKey string) ([]posture.Checks, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
@@ -102,11 +101,11 @@ type MockAccountManager struct {
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
@@ -282,11 +281,11 @@ func (am *MockAccountManager) AddPeer(
setupKey string,
userId string,
peer *nbpeer.Peer,
) (*nbpeer.Peer, *server.NetworkMap, error) {
) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.AddPeerFunc != nil {
return am.AddPeerFunc(setupKey, userId, peer)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
@@ -627,14 +626,6 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
}
// GetPeerAppliedPostureChecks mocks GetPeerAppliedPostureChecks of the AccountManager interface
func (am *MockAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
if am.GetPeerAppliedPostureChecksFunc != nil {
return am.GetPeerAppliedPostureChecksFunc(peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerAppliedPostureChecks is not implemented")
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
if am.UpdateAccountSettingsFunc != nil {
@@ -644,19 +635,19 @@ func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, ne
}
// LoginPeer mocks LoginPeer of the AccountManager interface
func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) {
func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.LoginPeerFunc != nil {
return am.LoginPeerFunc(login)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync, account)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface

View File

@@ -851,11 +851,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
return nil, err
}
_, _, err = am.AddPeer("", userID, peer1)
_, _, _, err = am.AddPeer("", userID, peer1)
if err != nil {
return nil, err
}
_, _, err = am.AddPeer("", userID, peer2)
_, _, _, err = am.AddPeer("", userID, peer2)
if err != nil {
return nil, err
}

View File

@@ -40,9 +40,9 @@ type Network struct {
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64
Serial uint64 `diff:"-"`
mu sync.Mutex `json:"-" gorm:"-"`
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0

View File

@@ -6,6 +6,8 @@ import (
"strings"
"time"
"github.com/mitchellh/hashstructure/v2"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -333,10 +335,10 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error)
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error) {
func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
if setupKey == "" && userID == "" {
// no auth method provided => reject access
return nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
}
upperKey := strings.ToUpper(setupKey)
@@ -350,7 +352,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
accountID, err = am.Store.GetAccountIDBySetupKey(setupKey)
}
if err != nil {
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
@@ -364,7 +366,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
@@ -383,7 +385,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
// The connecting peer should be able to recover with a retry.
_, err = account.FindPeerByPubKey(peer.Key)
if err == nil {
return nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
@@ -397,11 +399,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if !sk.IsValid() {
return nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
account.SetupKeys[sk.Key] = sk.IncrementUsage()
@@ -419,14 +421,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
registrationTime := time.Now().UTC()
@@ -453,7 +455,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
@@ -461,12 +463,12 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
if addedByUser {
groupsToAdd, err = account.getUserGroups(userID)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
}
@@ -483,7 +485,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
if addedByUser {
user, err := account.FindUser(userID)
if err != nil {
return nil, nil, status.Errorf(status.Internal, "couldn't find user")
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
}
user.updateLastLogin(newPeer.LastLogin)
}
@@ -492,7 +494,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
// Account is saved, we can release the lock
@@ -511,33 +513,35 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, peer)
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap)
return newPeer, networkMap, nil
return newPeer, networkMap, postureChecks, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, status.NewPeerNotRegisteredError()
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if peerLoginExpired(peer, account.Settings) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if sync.UpdateAccountPeers {
@@ -547,14 +551,16 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
var postureChecks []*posture.Checks
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, nil
return peer, emptyMap, postureChecks, nil
}
if isStatusChanged {
@@ -563,14 +569,16 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
}
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) {
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
@@ -596,18 +604,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
return am.AddPeer(login.SetupKey, login.UserID, newPeer)
}
log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
return nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
}
peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.NewPeerNotRegisteredError()
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
accSettings, err := am.Store.GetAccountSettings(accountID)
if err != nil {
return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
}
var isWriteLock bool
@@ -617,7 +626,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
switch {
case expired:
if err := checkAuth(login.UserID, peer); err != nil {
return nil, nil, err
return nil, nil, nil, err
}
isWriteLock = true
log.Debugf("peer login expired, acquiring write lock")
@@ -647,17 +656,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.NewPeerNotRegisteredError()
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
// this flag prevents unnecessary calls to the persistent store.
@@ -666,7 +675,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
if peerLoginExpired(peer, account.Settings) {
err = checkAuth(login.UserID, peer)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
@@ -677,7 +686,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
if err != nil {
return nil, nil, status.Errorf(status.Internal, "couldn't find user")
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
}
user.updateLastLogin(peer.LastLogin)
@@ -686,7 +695,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
peer, updated := updatePeerMeta(peer, login.Meta, account)
if updated {
@@ -695,17 +704,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
peer, err = am.checkAndUpdatePeerSSHKey(peer, account, login.SSHKey)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
if shouldStoreAccount {
if !isWriteLock {
log.Errorf("account %s should be stored but is not write locked", accountID)
return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
}
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
}
unlock()
@@ -715,19 +724,22 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
am.updateAccountPeers(account)
}
var postureChecks []*posture.Checks
if isRequiresApproval {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, nil
return peer, emptyMap, postureChecks, nil
}
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
}
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
@@ -904,6 +916,18 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
start := time.Now()
var skipUpdate int
defer func() {
duration := time.Since(start)
log.Printf("Finished execution of updateAccountPeers, took %v\n", duration)
log.Println("not updated peers: ", skipUpdate)
}()
if am.networkMapHash == nil {
am.networkMapHash = map[string]uint64{}
}
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -911,13 +935,33 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
log.Errorf("failed send out updates to peers, failed to validate peer: %v", err)
return
}
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
//if !am.peersUpdateManager.HasChannel(peer.ID) {
// log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
// continue
//}
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
hash, err := hashstructure.Hash(remotePeerNetworkMap, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Errorf("failed to generate network map hash: %v", err)
} else {
if am.networkMapHash[peer.ID] == hash {
log.Debugf("not sending network map update to peer: %s as there is nothing new", peer.ID)
skipUpdate++
continue
}
am.networkMapHash[peer.ID] = hash
}
postureChecks := am.getPeerPostureChecks(account, peer)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}

View File

@@ -18,35 +18,38 @@ type Peer struct {
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
SetupKey string `diff:"-" hash:"ignore"`
// IP address of the Peer
IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-" hash:"ignore"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-" hash:"ignore"`
// The user ID that registered the peer
UserID string
UserID string `diff:"-" hash:"ignore"`
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
LoginExpirationEnabled bool `diff:"-" hash:"ignore"`
// LastLogin the time when peer performed last login operation
LastLogin time.Time
LastLogin time.Time `diff:"-" hash:"ignore"`
// CreatedAt records the time the peer was created
CreatedAt time.Time
CreatedAt time.Time `diff:"-" hash:"ignore"`
// Indicate ephemeral peer attribute
Ephemeral bool
Ephemeral bool `diff:"-" hash:"ignore"`
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"`
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-" hash:"ignore"`
NetworkMap any `diff:"-" hash:"ignore"`
NetworkMapHash uint64 ` diff:"-" hash:"ignore"`
}
type PeerStatus struct { //nolint:revive

View File

@@ -92,7 +92,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
return
}
peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@@ -106,7 +106,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
t.Fatal(err)
return
}
_, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
_, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@@ -165,7 +165,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@@ -179,7 +179,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
t.Fatal(err)
return
}
peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@@ -341,7 +341,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
return
}
peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@@ -355,7 +355,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err)
return
}
_, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
_, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@@ -413,7 +413,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
return
}
peer1, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@@ -429,7 +429,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// the second peer added with a setup key
peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@@ -601,7 +601,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{
_, _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@@ -610,7 +610,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{
_, _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})

View File

@@ -7,7 +7,6 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
@@ -185,36 +184,14 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
return postureChecks, nil
}
// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer.
func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerKey)
if err != nil {
log.Errorf("failed while getting peer %s: %v", peerKey, err)
return nil, err
}
peer, err := account.FindPeerByPubKey(peerKey)
if err != nil {
return nil, status.Errorf(status.NotFound, "peer is not registered")
}
if peer == nil {
return nil, nil
}
peerPostureChecks := am.collectPeerPostureChecks(account, peer)
postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
postureChecksList = append(postureChecksList, check)
}
return postureChecksList, nil
}
// collectPeerPostureChecks collects the posture checks applied for a given peer.
func (am *DefaultAccountManager) collectPeerPostureChecks(account *Account, peer *nbpeer.Peer) map[string]posture.Checks {
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
peerPostureChecks := make(map[string]posture.Checks)
if len(account.PostureChecks) == 0 {
return nil
}
for _, policy := range account.Policies {
if !policy.Enabled {
continue
@@ -225,7 +202,13 @@ func (am *DefaultAccountManager) collectPeerPostureChecks(account *Account, peer
}
}
return peerPostureChecks
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
checkCopy := check
postureChecksList = append(postureChecksList, &checkCopy)
}
return postureChecksList
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.

View File

@@ -1,6 +1,7 @@
package server
import (
"fmt"
"net/netip"
"unicode/utf8"
@@ -51,7 +52,7 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route
// when create a new route routeID is newly generated so nothing will be skipped
// when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID {
continue
}
@@ -65,8 +66,9 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
group := account.GetGroup(groupID)
if group == nil {
return status.Errorf(
status.InvalidArgument, "failed to add route with prefix %s - peer group %s doesn't exist",
prefix.String(), groupID)
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID,
)
}
for _, pID := range group.Peers {
@@ -83,18 +85,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
}
if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists,
"failed to add route with prefix %s - peer %s already has this route", prefix.String(), peerID)
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
}
}
// check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existent before entering this function, o need to check again.
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf(
status.AlreadyExists, "failed to add route with prefix %s - peer group %s already has this route",
prefix.String(), group.Name)
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
getRouteDescriptor(prefix, domains), group.Name)
}
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
@@ -105,8 +107,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
return status.Errorf(status.AlreadyExists,
"failed to add route with prefix %s - peer %s from the group %s already has this route",
prefix.String(), peer.Name, group.Name)
"failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
}
}
}
@@ -114,6 +116,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
return nil
}
func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
if len(domains) > 0 {
return fmt.Sprintf("domains [%s]", domains.SafeString())
}
return fmt.Sprintf("prefix %s", prefix.String())
}
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)

View File

@@ -86,20 +86,38 @@ func getStoreEngineFromEnv() StoreEngine {
return SqliteStoreEngine
}
// getStoreEngine determines the store engine to use
func getStoreEngine(kind StoreEngine) StoreEngine {
// getStoreEngine determines the store engine to use.
// If no engine is specified, it attempts to retrieve it from the environment.
// If still not specified, it defaults to using SQLite.
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
func getStoreEngine(dataDir string, kind StoreEngine) StoreEngine {
if kind == "" {
kind = getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
// Migrate if it is the first run with a JSON file existing and no SQLite file present
jsonStoreFile := filepath.Join(dataDir, storeFileName)
sqliteStoreFile := filepath.Join(dataDir, storeSqliteFileName)
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
log.Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
// Attempt to migrate from JSON store to SQLite
if err := MigrateFileStoreToSqlite(dataDir); err != nil {
log.Errorf("failed to migrate filestore to SQLite: %v", err)
kind = FileStoreEngine
}
}
}
}
return kind
}
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
kind = getStoreEngine(kind)
kind = getStoreEngine(dataDir, kind)
if err := checkFileStoreEngine(kind, dataDir); err != nil {
return nil, err
@@ -113,7 +131,7 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S
log.Info("using Postgres store engine")
return newPostgresStore(metrics)
default:
return handleUnsupportedStoreEngine(kind, dataDir, metrics)
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
}
}
@@ -128,25 +146,6 @@ func checkFileStoreEngine(kind StoreEngine, dataDir string) error {
return nil
}
// handleUnsupportedStoreEngine handles cases where the store engine is unsupported
func handleUnsupportedStoreEngine(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
jsonStoreFile := filepath.Join(dataDir, storeFileName)
sqliteStoreFile := filepath.Join(dataDir, storeSqliteFileName)
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
log.Warnf("unsupported store engine, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
if err := MigrateFileStoreToSqlite(dataDir); err != nil {
return nil, fmt.Errorf("failed to migrate data to SQLite store: %w", err)
}
log.Info("using SQLite store engine")
return NewSqliteStore(dataDir, metrics)
}
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
}
// migrate migrates the SQLite database to the latest schema
func migrate(db *gorm.DB) error {
migrations := getMigrations()