Compare commits

...

27 Commits

Author SHA1 Message Date
Maycon Santos
5f5d597c59 fix GetSetupKeyByID 2025-06-26 13:54:04 +02:00
Maycon Santos
b85aad07d4 add getTXWithLockStrength method 2025-06-26 13:44:00 +02:00
Maycon Santos
7398836c2e Stop using locking share for read calls to avoid deadlocks
Added peer.userID index
2025-06-26 12:24:42 +02:00
Krzysztof Nazarewski (kdn)
34ac4e4b5a [misc] fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG (#4055)
* fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG

fixes https://github.com/netbirdio/netbird/issues/4054

* un-quote the number

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-06-26 10:45:00 +02:00
Pascal Fischer
52ff9d9602 [management] remove unused transaction (#4053) 2025-06-26 01:34:22 +02:00
Pascal Fischer
1b73fae46e [management] add breakdown of network map calculation metrics (#4020) 2025-06-25 11:46:35 +02:00
Viktor Liu
d897365abc [client] Don't open cmd.exe during MSI actions (#4041) 2025-06-24 21:32:37 +02:00
Viktor Liu
f37aa2cc9d [misc] Specify netbird binary location in Dockerfiles (#4024) 2025-06-23 10:09:02 +02:00
Maycon Santos
5343bee7b2 [management] check and log on new management version (#4029)
This PR enhances the version checker to send a custom User-Agent header when polling for updates, and configures both the management CLI and client UI to use distinct agents. 

- NewUpdate now takes an `httpAgent` string to set the User-Agent header.
- `fetchVersion` builds a custom HTTP request (instead of `http.Get`) and sets the User-Agent.
- Management CLI and client UI now pass `"nb/management"` and `"nb/client-ui"` respectively to NewUpdate.
- Tests updated to supply an `httpAgent` constant.
- Logs if there is a new version available for management
2025-06-22 16:44:33 +02:00
Maycon Santos
870e29db63 [misc] add additional metrics (#4028)
* add additional metrics

we are collecting active rosenpass, ssh from the client side
we are also collecting active user peers and active users

* remove duplicated
2025-06-22 13:44:25 +02:00
Maycon Santos
08e9b05d51 [client] close windows when process needs to exit (#4027)
This PR fixes a bug by ensuring that the advanced settings and re-authentication windows are closed appropriately when the main GUI process exits.

- Updated runSelfCommand calls throughout the UI to pass a context parameter.
- Modified runSelfCommand’s signature and its internal command invocation to use exec.CommandContext for proper cancellation handling.
2025-06-22 10:33:04 +02:00
hakansa
3581648071 [client] Refactor showLoginURL to improve error handling and connection status checks (#4026)
This PR refactors showLoginURL to improve error handling and connection status checks by delaying the login fetch until user interaction and closing the pop-up if already connected.

- Moved s.login(false) call into the click handler to defer network I/O.
- Added a conn.Status check after opening the URL to skip reconnection if already connected.
- Enhanced error logs for missing verification URLs and service status failures.
2025-06-22 10:03:58 +02:00
Viktor Liu
2a51609436 [client] Handle lazy routing peers that are part of HA groups (#3943)
* Activate new lazy routing peers if the HA group is active
* Prevent lazy peers going to idle if HA group members are active (#3948)
2025-06-20 18:07:19 +02:00
Pascal Fischer
83457f8b99 [management] add transaction for integrated validator groups update and primary account update (#4014) 2025-06-20 12:13:24 +02:00
Pascal Fischer
b45284f086 [management] export ephemeral peer flag on api (#4004) 2025-06-19 16:46:56 +02:00
Bethuel Mmbaga
e9016aecea [management] Add backward compatibility for older clients without firewall rules port range support (#4003)
Adds backward compatibility for clients with versions prior to v0.48.0 that do not support port range firewall rules.

- Skips generation of firewall rules with multi-port ranges for older clients
- Preserves support for single-port ranges by treating them as individual port rules, ensuring compatibility with older clients
2025-06-19 13:07:06 +03:00
Viktor Liu
23b5d45b68 [client] Fix port range squashing (#4007) 2025-06-18 18:56:48 +02:00
Viktor Liu
0e5dc9d412 [client] Add more Android advanced settings (#4001) 2025-06-18 17:23:23 +02:00
Zoltan Papp
91f7ee6a3c Fix route notification
On Android ignore the dynamic roots in the route notifications
2025-06-18 16:49:03 +02:00
Bethuel Mmbaga
7c6b85b4cb [management] Refactor routes to use store methods (#2928) 2025-06-18 16:40:29 +03:00
hakansa
08c9107c61 [client] fix connection state handling (#3995)
[client] fix connection state handling (#3995)
2025-06-17 17:14:08 +03:00
hakansa
81d83245e1 [client] Fix logic in updateStatus to correctly handle connection state (#3994)
[client] Fix logic in updateStatus to correctly handle connection state (#3994)
2025-06-17 17:02:04 +03:00
Maycon Santos
af2b427751 [management] Avoid recalculating next peer expiration (#3991)
* Avoid recalculating next peer expiration

- Check if an account schedule is already running
- Cancel executing schedules only when changes occurs
- Add more context info to logs

* fix tests
2025-06-17 15:14:11 +02:00
hakansa
f61ebdb3bc [client] Fix DNS Interceptor Build Error (#3993)
[client] Fix DNS Interceptor Build Error
2025-06-17 16:07:14 +03:00
Viktor Liu
de7384e8ea [client] Tighten allowed domains for dns forwarder (#3978) 2025-06-17 14:03:00 +02:00
Viktor Liu
75c1be69cf [client] Prioritze the local resolver in the dns handler chain (#3965) 2025-06-17 14:02:30 +02:00
hakansa
424ae28de9 [client] Fix UI Download URL (#3990)
[client] Fix UI Download URL
2025-06-17 11:55:48 +03:00
79 changed files with 2864 additions and 1143 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.18"
SIGN_PIPE_VER: "v0.0.19"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -1,6 +1,9 @@
FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,6 +1,7 @@
FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal"
)
// Preferences export a subset of the internal config for gomobile
// Preferences exports a subset of the internal config for gomobile
type Preferences struct {
configInput internal.ConfigInput
}
// NewPreferences create new Preferences instance
// NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{
ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci}
}
// GetManagementURL read url from config file
// GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err
}
// SetManagementURL store the given url and wait for commit
// SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url
}
// GetAdminURL read url from config file
// GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err
}
// SetAdminURL store the given url and wait for commit
// SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url
}
// GetPreSharedKey read preshared key from config file
// GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err
}
// SetPreSharedKey store the given key and wait for commit
// SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// SetRosenpassEnabled store if rosenpass is enabled
// SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
// GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
// SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
@@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)
return err

View File

@@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
"This overrides any policies received from the management service.")
}

View File

@@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int
iceBind *bind.ICEBind
tunAdapter TunAdapter
disableDNS bool
name string
device *device.Device
@@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
}
}
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil {
log.Errorf("failed to create Android interface: %s", err)

View File

@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
DisableDNS bool
}
// WGIface represents an interface instance

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil

View File

@@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},

View File

@@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{

View File

@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
}
if _, err := config.apply(input); err != nil {
@@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
if runtime.GOOS == "android" {
// default to disabled SSH on Android for security
log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true
}

View File

@@ -175,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {

View File

@@ -11,9 +11,10 @@ import (
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 1
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityUpstream = 50
PriorityDefault = 1
)
type SubdomainMatcher interface {

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true,
nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: true,
},
},
{
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityUpstream},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDNSRoute: true,
nbdns.PriorityUpstream: false,
},
},
{
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityUpstream},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true,
nbdns.PriorityDNSRoute: false,
nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true,
},
},
}
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
// Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"example.com.", nbdns.PriorityUpstream, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
},
query: "test.sub.example.com.",
expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
},
query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",

View File

@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain,
handler: s.localResolver,
priority: PriorityMatchDomain,
priority: PriorityLocal,
})
for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain
basePriority := PriorityUpstream
if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault
}
@@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i
// Check if we're about to overlap with the next priority tier
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
// Check if we're about to overlap with the next priority tier.
// This boundary check ensures that the priority of upstream handlers does not conflict
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
domainGroup.domain, PriorityUpstream-PriorityDefault)
break
}

View File

@@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityLocal,
},
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone,
@@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
initSerial: 0,
@@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
initSerial: 0,
@@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
initSerial: 0,
@@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
}
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
testCases := []struct {
name string
@@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
}
@@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
"upstream-other": {
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
}
@@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
},
expectedHandlers: map[string]string{
@@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
expectedHandlers: map[string]string{
@@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain + 1,
priority: PriorityUpstream + 1,
},
// Keep existing groups with their original priorities
{
@@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
},
expectedHandlers: map[string]string{
@@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
// Add group3 with lowest priority
{
@@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain - 2,
priority: PriorityUpstream - 2,
},
},
expectedHandlers: map[string]string{
@@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
expectedHandlers: map[string]string{
@@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
priority: PriorityUpstream - 1,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
{
domain: "new.com",
handler: &mockHandler{
Id: "upstream-new",
},
priority: PriorityMatchDomain,
priority: PriorityUpstream,
},
},
expectedHandlers: map[string]string{
@@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
// Verify refcount is 2
zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
// Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
}
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
}
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -2,6 +2,7 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():
log.Tracef("%s has been stopped", u)
logger.Tracef("%s has been stopped", u)
return
default:
}
@@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
}
if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
}
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
}
// count the fails only if they happen sequentially
u.failsCount.Store(0)
return
}
u.failsCount.Add(1)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
}
@@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -18,14 +18,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct {
listenAddress string
ttl uint32
@@ -38,16 +44,18 @@ type DNSForwarder struct {
mutex sync.RWMutex
fwdEntries []*ForwarderEntry
firewall firewall.Manager
firewall firewaller
resolver resolver
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
}
}
@@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server
mux := dns.NewServeMux()
f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
// TCP server
tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{
Addr: f.listenAddress,
Net: "tcp",
@@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown)
return <-errCh
}
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
}
func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(w, query, resp, domain, err)
return nil
}
f.updateInternalState(domain, ips)
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
@@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
}
}
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" {
for _, ip := range ips {
var prefix netip.Prefix
@@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches
}
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,11 +1,21 @@
package dnsfwd
import (
"context"
"fmt"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -13,7 +23,7 @@ import (
func Test_getMatchingEntries(t *testing.T) {
testCases := []struct {
name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId
storedMappings map[string]route.ResID
queryDomain string
expectedResId route.ResID
}{
@@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
{
name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com",
queryDomain: "foo.example.org",
expectedResId: "",
},
{
@@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
})
}
}
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
}
switch runtime.GOOS {

View File

@@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() {
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@@ -58,7 +58,7 @@ type Manager struct {
// Route HA group management
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex // protects route mappings
routesMu sync.RWMutex
onInactive chan peerid.ConnID
}
@@ -146,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) {
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
m.onPeerInactivityTimedOut(ctx, peerConnID)
}
}
}
@@ -197,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
return added
}
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -225,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(ctx, peerCfg)
}
return false, nil
}
@@ -315,36 +322,38 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
var peersToActivate []string
m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
if len(haGroups) == 0 {
m.routesMu.RUnlock()
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return
}
activatedCount := 0
for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
for _, peerID := range peers {
if peerID == triggerPeerID {
continue
if peerID != triggerPeerID {
peersToActivate = append(peersToActivate, peerID)
}
}
}
m.routesMu.RUnlock()
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
activatedCount := 0
for _, peerID := range peersToActivate {
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
}
@@ -354,6 +363,51 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
}
}
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
@@ -415,6 +469,48 @@ func (m *Manager) close() {
log.Infof("lazy connection manager closed")
}
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
}
}
return false
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -441,7 +537,7 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
@@ -456,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
return
}
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
}
return
}
mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized
@@ -489,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return
}

View File

@@ -317,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig
}
// IsConnected unit tests only
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
// IsConnected returns true if the peer is connected
func (conn *Conn) IsConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
return conn.currentConnPriority != conntype.None
return conn.evalStatus() == StatusConnected
}
func (conn *Conn) GetKey() string {

View File

@@ -144,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 {
return
}
log.Tracef("received DNS request for domain=%s type=%v class=%v",
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query")
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
@@ -161,19 +164,19 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock()
if peerKey == "" {
d.writeDNSError(w, r, "no current peer key")
d.writeDNSError(w, r, logger, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("create DNS client: %v", err))
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
@@ -184,9 +187,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
logger.Errorf("failed writing DNS response: %v", err)
}
return
}
@@ -196,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err)
logger.Errorf("failed writing DNS response: %v", err)
}
}
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err)
logger.Errorf("failed to write DNS error response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err)
logger.Errorf("failed writing DNS continue response: %v", err)
}
}

View File

@@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager
type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap

View File

@@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() {
continue
}
@@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0)
var newNets []string
for _, routes := range idMap {
for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String())
}
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
n.routeRanges = newNets
n.notify()
}
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
@@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
if !n.hasDiff(n.routeRanges, newNets) {
return
}
n.routeRanges = newNets
n.notify()
}

View File

@@ -1,8 +1,10 @@
<Wix
xmlns="http://wixtoolset.org/schemas/v4/wxs">
xmlns="http://wixtoolset.org/schemas/v4/wxs"
xmlns:util="http://wixtoolset.org/schemas/v4/wxs/util">
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="NetBird GmbH" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
<MediaTemplate EmbedCab="yes" />
<Feature Id="NetbirdFeature" Title="Netbird" Level="1">
@@ -46,29 +48,10 @@
<ComponentRef Id="NetbirdFiles" />
</ComponentGroup>
<Property Id="cmd" Value="cmd.exe"/>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
<CustomAction Id="KillDaemon"
ExeCommand='/c "taskkill /im netbird.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<CustomAction Id="KillUI"
ExeCommand='/c "taskkill /im netbird-ui.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<InstallExecuteSequence>
<!-- For Uninstallation -->
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
</InstallExecuteSequence>
<!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />

View File

@@ -59,16 +59,16 @@ type Info struct {
Environment Environment
Files []File // for posture checks
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
LazyConnectionEnabled bool
}

View File

@@ -280,7 +280,7 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool
showAdvancedSettings: showSettings,
showNetworks: showNetworks,
update: version.NewUpdate(),
update: version.NewUpdate("nb/client-ui"),
}
s.eventHandler = newEventHandler(s)
@@ -572,7 +572,7 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool
switch {
case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled():
case status.Status == string(internal.StatusConnected):
s.connected = true
s.sendNotification = true
if s.isUpdateIconActive {
@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
func (s *serviceClient) onSessionExpire() {
s.sendNotification = true
if s.sendNotification {
s.eventHandler.runSelfCommand("login-url", "true")
s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
s.sendNotification = false
}
}
@@ -992,21 +992,6 @@ func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
func (s *serviceClient) showLoginURL() {
resp, err := s.login(false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
}
verificationURL := resp.VerificationURIComplete
if verificationURL == "" {
verificationURL = resp.VerificationURI
}
if verificationURL == "" {
log.Error("no verification URL provided in the login response")
return
}
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
if s.wLoginURL == nil {
@@ -1025,6 +1010,21 @@ func (s *serviceClient) showLoginURL() {
return
}
resp, err := s.login(false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
}
verificationURL := resp.VerificationURIComplete
if verificationURL == "" {
verificationURL = resp.VerificationURI
}
if verificationURL == "" {
log.Error("no verification URL provided in the login response")
return
}
if err := openURL(verificationURL); err != nil {
log.Errorf("failed to open login URL: %v", err)
return
@@ -1038,7 +1038,19 @@ func (s *serviceClient) showLoginURL() {
}
label.SetText("Re-authentication successful.\nReconnecting")
time.Sleep(300 * time.Millisecond)
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return
}
if status.Status == string(internal.StatusConnected) {
label.SetText("Already connected.\nClosing this window.")
time.Sleep(2 * time.Second)
s.wLoginURL.Close()
return
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")

View File

@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2"
"fyne.io/systray"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
type eventHandler struct {
@@ -120,7 +122,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() {
defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig()
h.runSelfCommand("settings", "true")
h.runSelfCommand(h.client.ctx, "settings", "true")
}()
}
@@ -128,7 +130,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable()
go func() {
defer h.client.mCreateDebugBundle.Enable()
h.runSelfCommand("debug", "true")
h.runSelfCommand(h.client.ctx, "debug", "true")
}()
}
@@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() {
}
func (h *eventHandler) handleUpdateClick() {
if err := openURL("https://netbird.io/download"); err != nil {
if err := openURL(version.DownloadUrl()); err != nil {
log.Errorf("failed to open download URL: %v", err)
}
}
@@ -152,7 +154,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable()
go func() {
defer h.client.mNetworks.Enable()
h.runSelfCommand("networks", "true")
h.runSelfCommand(h.client.ctx, "networks", "true")
}()
}
@@ -170,14 +172,14 @@ func (h *eventHandler) updateConfigWithErr() {
}
}
func (h *eventHandler) runSelfCommand(command, arg string) {
func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("error getting executable path: %v", err)
return
}
cmd := exec.Command(proc,
cmd := exec.CommandContext(ctx, proc,
fmt.Sprintf("--%s=%s", command, arg),
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
)

View File

@@ -60,7 +60,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-0}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# Dashboard

View File

@@ -357,6 +357,13 @@ var (
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
update := version.NewUpdate("nb/management")
update.SetDaemonVersion(version.NetbirdVersion())
update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
defer update.StopWatch()
SetupCloseHandler()
<-stopCh

View File

@@ -24,6 +24,7 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -369,7 +370,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return err
}
@@ -409,14 +410,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C
event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
am.schedulePeerLoginExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
}
}
@@ -454,6 +456,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
//nolint
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -478,8 +484,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
}
}
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) {
if am.peerLoginExpiry.IsSchedulerRunning(accountID) {
log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID)
return
}
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
}
@@ -699,7 +708,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID)
return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
}
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -711,7 +720,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided")
}
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -766,7 +775,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err
}
@@ -820,7 +829,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -850,7 +859,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err
@@ -1001,7 +1010,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID)
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err
@@ -1011,7 +1020,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
@@ -1176,7 +1185,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
}
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
@@ -1196,7 +1205,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1228,7 +1237,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return err
}
@@ -1254,12 +1263,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool
var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -1289,7 +1298,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -1299,7 +1308,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
@@ -1331,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
@@ -1344,7 +1353,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
@@ -1405,7 +1414,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
}
if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists {
return "", err
}
@@ -1429,7 +1438,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@@ -1450,7 +1459,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
}
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1465,7 +1474,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil {
cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1476,7 +1485,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@@ -1486,7 +1495,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
}
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err
@@ -1497,7 +1506,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
}
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err
@@ -1627,7 +1636,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil {
return false, err
}
@@ -1649,7 +1658,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
@@ -1675,7 +1684,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1761,7 +1770,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil {
return nil, false, err
}
@@ -1781,7 +1790,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 {
accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists {
continue
}
@@ -1844,47 +1853,56 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
}
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
account, err := am.Store.GetAccount(ctx, accountId)
var account *types.Account
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
account, err = transaction.GetAccount(ctx, accountId)
if err != nil {
return err
}
if account.IsDomainPrimaryAccount {
return nil
}
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := transaction.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return status.Errorf(status.Internal, "failed to update account to primary")
}
return nil
})
if err != nil {
return nil, err
}
if account.IsDomainPrimaryAccount {
return account, nil
}
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return nil, err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return nil, status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return nil, status.Errorf(status.Internal, "failed to update account to primary")
}
return account, nil
}
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, false, err
}
@@ -1894,7 +1912,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsMap[group.ID] = group
}
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, false, err
}
@@ -1902,7 +1920,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsToUpdate := make(map[string]*types.Group)
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil {
return false, false, err
}

View File

@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return
}
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -899,11 +899,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
}
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err)
assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err)
assert.Len(t, pats, 0)
}
@@ -1775,7 +1775,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings)
@@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{}
wg.Add(2)
wg.Add(1)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
@@ -1963,7 +1960,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2646,7 +2643,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
@@ -2660,7 +2657,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})
@@ -2674,11 +2671,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
})
@@ -2695,11 +2692,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
})
@@ -2713,7 +2710,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@@ -2727,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@@ -2741,11 +2738,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
@@ -2759,7 +2756,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2774,7 +2771,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
})
@@ -3357,7 +3354,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID)
@@ -3368,7 +3365,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err)
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
@@ -3379,7 +3376,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID)
@@ -3406,7 +3403,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
@@ -3423,7 +3420,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
})
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = []string{"group1"}
@@ -3434,7 +3431,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return userAuth, err
}
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err
}
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil {
return err
}
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err
})
if err != nil {

View File

@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}

View File

@@ -122,7 +122,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare)
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return

View File

@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
}
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil {
return nil, err
}
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue
}
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil {
// @todo consider logging
continue

View File

@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}
// SaveGroup object of the peers
@@ -140,7 +140,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -152,13 +152,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
}
modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil
@@ -431,7 +431,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
}
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
@@ -448,7 +448,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
@@ -460,7 +460,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to get user")
}
@@ -506,7 +506,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings")
}
@@ -515,7 +515,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get account settings")
}
@@ -529,7 +529,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
@@ -549,7 +549,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
@@ -567,7 +567,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
@@ -586,7 +586,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
@@ -602,7 +602,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
@@ -618,7 +618,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil
@@ -638,7 +638,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
@@ -664,18 +664,9 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil {
return false, err
}

View File

@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err
}
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
}
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err)
}
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return nil, fmt.Errorf("error getting group: %w", err)
}
// TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID)
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err)
}
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err)
}
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return nil, fmt.Errorf("error getting group: %w", err)
}
// TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err)
}

View File

@@ -426,6 +426,10 @@ components:
items:
type: string
example: "stage-host-1"
ephemeral:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
required:
- city_name
- connected
@@ -450,6 +454,7 @@ components:
- approval_required
- serial_number
- extra_dns_labels
- ephemeral
AccessiblePeer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'

View File

@@ -1016,6 +1016,9 @@ type Peer struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"`
@@ -1097,6 +1100,9 @@ type PeerBatch struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"`

View File

@@ -365,6 +365,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
}
}

View File

@@ -37,21 +37,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID)
if err != nil {
return err
}
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
a, err := transaction.GetAccountByUser(ctx, userID)
if err != nil {
return err
}
var extra *types.ExtraSettings
var extra *types.ExtraSettings
if a.Settings.Extra != nil {
extra = a.Settings.Extra
} else {
extra = &types.ExtraSettings{}
a.Settings.Extra = extra
}
extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(ctx, a)
if a.Settings.Extra != nil {
extra = a.Settings.Extra
} else {
extra = &types.ExtraSettings{}
a.Settings.Extra = extra
}
extra.IntegratedValidatorGroups = groups
return transaction.SaveAccount(ctx, a)
})
}
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
@@ -61,7 +63,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
if err != nil {
return err
}
@@ -81,20 +83,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
var peers []*nbpeer.Peer
var settings *types.Settings
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
return err
})
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}

View File

@@ -641,7 +641,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return

View File

@@ -184,7 +184,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
ephemeralPeersSKs int
ephemeralPeersSKUsage int
activePeersLastDay int
activeUserPeersLastDay int
osPeers map[string]int
activeUsersLastDay map[string]struct{}
userPeers int
rules int
rulesProtocol map[string]int
@@ -203,6 +205,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
version string
peerActiveVersions []string
osUIClients map[string]int
rosenpassEnabled int
)
start := time.Now()
metricsProperties := make(properties)
@@ -210,6 +213,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
osUIClients = make(map[string]int)
rulesProtocol = make(map[string]int)
rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion()
@@ -277,10 +281,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
for _, peer := range account.Peers {
peers++
if peer.SSHEnabled {
if peer.SSHEnabled || peer.Meta.Flags.ServerSSHAllowed {
peersSSHEnabled++
}
if peer.Meta.Flags.RosenpassEnabled {
rosenpassEnabled++
}
if peer.UserID != "" {
userPeers++
}
@@ -299,6 +307,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
_, connected := connections[peer.ID]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
if peer.UserID != "" {
activeUserPeersLastDay++
activeUsersLastDay[peer.UserID] = struct{}{}
}
osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1
@@ -320,6 +332,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs
metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage
metricsProperties["active_peers_last_day"] = activePeersLastDay
metricsProperties["active_user_peers_last_day"] = activeUserPeersLastDay
metricsProperties["active_users_last_day"] = len(activeUsersLastDay)
metricsProperties["user_peers"] = userPeers
metricsProperties["rules"] = rules
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
@@ -338,6 +352,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["ui_clients"] = uiClient
metricsProperties["idp_manager"] = w.idpManager
metricsProperties["store_engine"] = w.dataSource.GetStoreEngine()
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
for protocol, count := range rulesProtocol {
metricsProperties["rules_protocol_"+protocol] = count

View File

@@ -47,8 +47,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": {
ID: "1",
UserID: "test",
SSHEnabled: true,
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
SSHEnabled: false,
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1", Flags: nbpeer.Flags{ServerSSHAllowed: true, RosenpassEnabled: true}},
},
},
Policies: []*types.Policy{
@@ -312,7 +312,19 @@ func TestGenerateProperties(t *testing.T) {
}
if properties["posture_checks"] != 2 {
t.Errorf("expected 1 posture_checks, got %d", properties["posture_checks"])
t.Errorf("expected 2 posture_checks, got %d", properties["posture_checks"])
}
if properties["rosenpass_enabled"] != 1 {
t.Errorf("expected 1 rosenpass_enabled, got %d", properties["rosenpass_enabled"])
}
if properties["active_user_peers_last_day"] != 2 {
t.Errorf("expected 2 active_user_peers_last_day, got %d", properties["active_user_peers_last_day"])
}
if properties["active_users_last_day"] != 1 {
t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"])
}
}

View File

@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
}
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err
}
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
if err != nil {
return err
}

View File

@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID)
return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
}
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {

View File

@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
}
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
return nil, status.NewPermissionDeniedError()
}
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err)
}
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
return nil, status.NewPermissionDeniedError()
}
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err)
}
@@ -218,17 +218,17 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
}
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID {
return status.Errorf(status.InvalidArgument, "new resource name already exists")
}
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}

View File

@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
}
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError()
}
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get network routers: %w", err)
}
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
return nil, status.NewPermissionDeniedError()
}
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID)
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err)
}
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
}
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
}

View File

@@ -35,7 +35,7 @@ import (
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
@@ -45,7 +45,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, status.NewPermissionValidationError(err)
}
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter)
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
if err != nil {
return nil, err
}
@@ -55,7 +55,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return accountPeers, nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get account settings: %w", err)
}
@@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@@ -127,13 +127,13 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
am.schedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
@@ -216,7 +216,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err
}
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
}
}
@@ -334,7 +335,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return err
}
@@ -467,7 +468,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false
if len(userID) > 0 {
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
}
@@ -487,7 +488,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
@@ -583,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels,
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
@@ -673,7 +674,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
@@ -705,7 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var err error
var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -717,7 +718,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil {
return err
}
@@ -820,7 +821,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
var isPeerUpdated bool
var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -905,7 +906,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -929,7 +930,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs)
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
if err != nil {
return nil, err
}
@@ -944,7 +945,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources)
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
@@ -969,7 +970,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -980,7 +981,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -999,7 +1000,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -1061,7 +1062,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
@@ -1103,7 +1104,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
@@ -1116,7 +1117,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return peer, nil
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
@@ -1142,13 +1143,13 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well.
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
if err != nil {
return nil, err
}
for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID {
return peer, nil
@@ -1168,7 +1169,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return
}
start := time.Now()
globalStart := time.Now()
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
@@ -1203,18 +1204,27 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return
}
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
@@ -1222,7 +1232,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return
}
start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
@@ -1231,7 +1244,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
wg.Wait()
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
}
}
@@ -1315,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return peerSchedulerRetryInterval, true
@@ -1325,7 +1338,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true
@@ -1359,7 +1372,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return peerSchedulerRetryInterval, true
@@ -1369,7 +1382,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true
@@ -1400,12 +1413,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1423,12 +1436,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1446,12 +1459,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
return am.Store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
@@ -1465,7 +1478,7 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
}
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1492,7 +1505,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1503,7 +1516,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err
}
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1564,7 +1577,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil

View File

@@ -31,7 +31,7 @@ type Peer struct {
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string
UserID string `gorm:"index"`
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer

View File

@@ -1301,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
@@ -1423,7 +1423,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key)
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
@@ -1505,7 +1505,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
require.Error(t, err)
account, err := s.GetAccount(context.Background(), existingAccountID)
@@ -1671,7 +1671,7 @@ func Test_LoginPeer(t *testing.T) {
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey)
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")

View File

@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}

View File

@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
return true, nil
}
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return false, err
}

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
}
// SavePolicy in the store
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
}
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return false, err
}
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return err
}
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.AccountID = accountID
}
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil {
return err
}
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil {
return err
}

View File

@@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
ID: "peerB",
IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
},
"peerC": {
ID: "peerC",
@@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
IP: net.ParseIP("100.65.31.2"),
Status: &nbpeer.PeerStatus{},
},
"peerK": {
ID: "peerK",
IP: net.ParseIP("100.32.80.1"),
Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
},
},
Groups: map[string]*types.Group{
"GroupAll": {
@@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerI",
},
},
"GroupWorkflow": {
ID: "GroupWorkflow",
Name: "workflow",
Peers: []string{
"peerK",
},
},
},
Policies: []*types.Policy{
{
@@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
},
},
},
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
PortRanges: []types.RulePortRange{
{
Start: 8088,
End: 8088,
},
{
Start: 9090,
End: 9095,
},
},
Sources: []string{
"GroupWorkflow",
},
Destinations: []string{
"GroupDMZ",
},
},
},
},
},
}
@@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.True(t, contains, "rule not found in expected rules %#v", rule)
}
})
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
}
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
})
}
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
@@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
@@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm",
},
}
assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
}
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN,
@@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm",
},
}
assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
}
})
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
@@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm",
},
}
assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
}
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN,
@@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm",
},
}
assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
}
})
}
@@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{
@@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
}
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
peerVersion := sanitizeVersion(peer.Meta.WtVersion)
minVersion := sanitizeVersion(n.MinVersion)
peerNBVersion, err := version.NewVersion(peerVersion)
meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
if err != nil {
return false, err
}
constraints, err := version.NewConstraint(">= " + minVersion)
if err != nil {
return false, err
}
if constraints.Check(peerNBVersion) {
if meetsMin {
return true, nil
}
@@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
}
return nil
}
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
peerVer = sanitizeVersion(peerVer)
minVer = sanitizeVersion(minVer)
peerNBVer, err := version.NewVersion(peerVer)
if err != nil {
return false, err
}
constraints, err := version.NewConstraint(">= " + minVer)
if err != nil {
return false, err
}
return constraints.Check(peerNBVer), nil
}

View File

@@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
})
}
}
func TestMeetsMinVersion(t *testing.T) {
tests := []struct {
name string
minVer string
peerVer string
want bool
wantErr bool
}{
{
name: "Peer version greater than min version",
minVer: "0.26.0",
peerVer: "0.60.1",
want: true,
wantErr: false,
},
{
name: "Peer version equals min version",
minVer: "1.0.0",
peerVer: "1.0.0",
want: true,
wantErr: false,
},
{
name: "Peer version less than min version",
minVer: "1.0.0",
peerVer: "0.9.9",
want: false,
wantErr: false,
},
{
name: "Peer version with pre-release tag greater than min version",
minVer: "1.0.0",
peerVer: "1.0.1-alpha",
want: true,
wantErr: false,
},
{
name: "Invalid peer version format",
minVer: "1.0.0",
peerVer: "dev",
want: false,
wantErr: true,
},
{
name: "Invalid min version format",
minVer: "invalid.version",
peerVer: "1.0.0",
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
if err != nil {
return err
}
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
return err
}
return nil
}
// For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}

View File

@@ -4,19 +4,19 @@ import (
"context"
"fmt"
"net/netip"
"slices"
"unicode/utf8"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
prefix := checkRoute.Network
domains := checkRoute.Domains
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool)
@@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route
// when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID {
if checkRoute.ID == prefixRoute.ID {
continue
}
if prefixRoute.Peer != "" {
seenPeers[string(prefixRoute.ID)] = true
}
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true
group := account.GetGroup(groupID)
if group == nil {
group, ok := peerGroupsMap[groupID]
if !ok || group == nil {
return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID,
@@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
}
}
if peerID != "" {
if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID)
if peer == nil {
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
}
// check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
for _, groupID := range checkRoute.PeerGroups {
group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf(
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
@@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
}
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
if err != nil {
return err
}
for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id)
if peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
peer, ok := peersMap[id]
if !ok || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
}
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
if len(domains) == 0 && !prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
}
var newRoute *route.Route
var updateAccountPeers bool
if len(domains) > 0 {
prefix = getPlaceholderIP()
}
if peerID != "" && len(peerGroupIDs) != 0 {
return nil, status.Errorf(
status.InvalidArgument,
"peer with ID %s and peers group %s should not be provided at the same time",
peerID, peerGroupIDs)
}
var newRoute route.Route
newRoute.ID = route.ID(xid.New().String())
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
if err != nil {
return nil, err
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
newRoute = &route.Route{
ID: route.ID(xid.New().String()),
AccountID: accountID,
Network: prefix,
Domains: domains,
KeepRoute: keepRoute,
NetID: netID,
Description: description,
Peer: peerID,
PeerGroups: peerGroupIDs,
NetworkType: networkType,
Masquerade: masquerade,
Metric: metric,
Enabled: enabled,
Groups: groups,
AccessControlGroups: accessControlGroupIDs,
}
}
if len(accessControlGroupIDs) > 0 {
err = validateGroups(accessControlGroupIDs, account.Groups)
if err != nil {
return nil, err
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
return err
}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
})
if err != nil {
return nil, err
}
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if am.isRouteChangeAffectPeers(account, &newRoute) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return newRoute, nil
}
// SaveRoute saves route
@@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
return err
}
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
if err != nil {
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
}
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil")
}
@@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
}
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
if err != nil {
return err
}
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
}
// validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
if err != nil {
return nil, err
}
if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups)
if err != nil {
return err
if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
return nil, err
}
}
if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
if err != nil {
return err
if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
return nil, err
}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil {
return err
if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
return nil, err
}
err = validateGroups(routeToSave.Groups, account.Groups)
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if am.isRouteChangeAffectPeers(account, routy) {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
return groupsMap, nil
}
func toProtocolRoute(route *route.Route) *proto.Route {
@@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
return &portInfo
}
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
}

View File

@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *types.Group
for _, group := range groups {

View File

@@ -12,6 +12,7 @@ import (
type Scheduler interface {
Cancel(ctx context.Context, IDs []string)
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
IsSchedulerRunning(ID string) bool
}
// MockScheduler is a mock implementation of Scheduler
@@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) {
mock.CancelFunc(ctx, IDs)
return
}
log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ")
log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ")
}
// Schedule mocks the Schedule function of the Scheduler interface
@@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st
mock.ScheduleFunc(ctx, in, ID, job)
return
}
log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined")
log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined")
}
func (mock *MockScheduler) IsSchedulerRunning(ID string) bool {
// MockScheduler does not implement IsSchedulerRunning, so we return false
log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined")
return false
}
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
@@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s
}()
}
// IsSchedulerRunning checks if a job with the provided ID is scheduled to run
func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool {
wm.mu.Lock()
defer wm.mu.Unlock()
_, ok := wm.jobs[ID]
return ok
}

View File

@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
return nil, fmt.Errorf("get extra settings: %w", err)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("get account settings: %w", err)
}
@@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
return nil, fmt.Errorf("get extra settings: %w", err)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("get account settings: %w", err)
}

View File

@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
}
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyToSave.Id)
if err != nil {
return err
}
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil {
return nil, err
}
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil {
return err
}
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
}
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
if err != nil {
return err
}
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil

View File

@@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error {
func NewOperationNotFoundError(operation operations.Operation) error {
return Errorf(NotFound, "operation: %s not found", operation)
}
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}

View File

@@ -23,8 +23,6 @@ import (
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -34,6 +32,7 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
)
@@ -479,7 +478,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
if err != nil {
return nil, err
}
@@ -489,11 +488,7 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&types.Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
@@ -543,11 +538,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
}
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var user types.User
result := tx.
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
@@ -564,11 +555,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var user types.User
result := tx.First(&user, idQueryCondition, userID)
if result.Error != nil {
@@ -601,11 +588,7 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
@@ -620,11 +603,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var user types.User
result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
@@ -638,11 +617,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
@@ -657,11 +632,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
likePattern := `%"ID":"` + resourceID + `"%`
@@ -707,11 +678,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
}
func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountMeta types.AccountMeta
result := tx.Model(&types.Account{}).
First(&accountMeta, idQueryCondition, accountID)
@@ -881,11 +848,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&types.User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
@@ -900,11 +863,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
}
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
@@ -937,11 +896,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
@@ -969,11 +924,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var labels []string
result := tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
@@ -991,11 +942,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1007,11 +954,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peer nbpeer.Peer
result := tx.First(&peer, GetKeyQueryCondition(s), peerKey)
@@ -1026,11 +969,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountSettings types.AccountSettings
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1042,11 +981,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
}
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var createdBy string
result := tx.Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID)
@@ -1244,11 +1179,7 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var setupKey types.SetupKey
result := tx.
First(&setupKey, GetKeyQueryCondition(s), key)
@@ -1392,11 +1323,7 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
@@ -1410,8 +1337,9 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID)
query := tx.Where(accountIDCondition, accountID)
if nameFilter != "" {
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
@@ -1430,11 +1358,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
@@ -1462,11 +1386,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStr
// GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peer *nbpeer.Peer
result := tx.
First(&peer, accountAndIDQueryCondition, accountID, peerID)
@@ -1482,11 +1402,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
// GetPeersByIDs retrieves peers by their IDs and account ID.
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
if result.Error != nil {
@@ -1504,11 +1420,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1523,11 +1435,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
result := tx.
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1542,11 +1450,7 @@ func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStreng
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := tx.
Where("ephemeral = ?", true).
@@ -1624,11 +1528,7 @@ func (s *SqlStore) GetDB() *gorm.DB {
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountDNSSettings types.AccountDNSSettings
result := tx.Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
@@ -1644,11 +1544,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&types.Account{}).
Select("id").First(&accountID, idQueryCondition, id)
@@ -1664,11 +1560,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var account types.Account
result := tx.Model(&types.Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
@@ -1684,11 +1576,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
@@ -1704,11 +1592,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var group types.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently
@@ -1737,11 +1621,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
@@ -1797,11 +1677,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var policies []*types.Policy
result := tx.
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
@@ -1815,11 +1691,7 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var policy *types.Policy
result := tx.Preload(clause.Associations).
@@ -1881,11 +1753,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil {
@@ -1898,10 +1766,7 @@ func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength Loc
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var postureCheck *posture.Checks
result := tx.
@@ -1919,11 +1784,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil {
@@ -1968,21 +1829,63 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db, lockStrength, accountID)
tx := s.getTXWithLockStrength(lockStrength)
var routes []*route.Route
result := tx.Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var route *route.Route
result := tx.First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID)
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.NewRouteNotFoundError(routeID)
}
return nil
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var setupKeys []*types.SetupKey
result := tx.
Find(&setupKeys, accountIDCondition, accountID)
@@ -1996,14 +1899,9 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var setupKey *types.SetupKey
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
result := tx.First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
@@ -2043,11 +1941,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var nsGroups []*nbdns.NameServerGroup
result := tx.Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
@@ -2060,11 +1954,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var nsGroup *nbdns.NameServerGroup
result := tx.
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
@@ -2104,49 +1994,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record []T
result := tx.Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record T
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
@@ -2180,11 +2027,7 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var networks []*networkTypes.Network
result := tx.Find(&networks, accountIDCondition, accountID)
if result.Error != nil {
@@ -2196,11 +2039,7 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
}
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var network *networkTypes.Network
result := tx.
First(&network, accountAndIDQueryCondition, accountID, networkID)
@@ -2242,11 +2081,7 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
}
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netRouters []*routerTypes.NetworkRouter
result := tx.
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
@@ -2259,11 +2094,7 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
}
func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netRouters []*routerTypes.NetworkRouter
result := tx.
Find(&netRouters, accountIDCondition, accountID)
@@ -2276,11 +2107,7 @@ func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrengt
}
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netRouter *routerTypes.NetworkRouter
result := tx.
First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
@@ -2321,11 +2148,7 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources []*resourceTypes.NetworkResource
result := tx.
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
@@ -2338,11 +2161,7 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
}
func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources []*resourceTypes.NetworkResource
result := tx.
Find(&netResources, accountIDCondition, accountID)
@@ -2355,11 +2174,7 @@ func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStren
}
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources *resourceTypes.NetworkResource
result := tx.
First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
@@ -2375,11 +2190,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources *resourceTypes.NetworkResource
result := tx.
First(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
@@ -2421,10 +2232,7 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var pat types.PersonalAccessToken
result := tx.First(&pat, "hashed_token = ?", hashedToken)
@@ -2441,11 +2249,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var pat types.PersonalAccessToken
result := tx.
First(&pat, "id = ? AND user_id = ?", patID, userID)
@@ -2462,11 +2266,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
// GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var pats []*types.PersonalAccessToken
result := tx.Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil {
@@ -2526,11 +2326,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
}
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
jsonValue := fmt.Sprintf(`"%s"`, ip.String())
var peer nbpeer.Peer
@@ -2557,3 +2353,11 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
return count, nil
}
func (s *SqlStore) getTXWithLockStrength(lockStrength LockingStrength) *gorm.DB {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
return tx
}

View File

@@ -19,21 +19,17 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
)
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
@@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 8003, len(accountGroups))
}
func TestSqlStore_GetAccountRoutes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve routes by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, routes, tt.expectedCount)
})
}
}
func TestSqlStore_GetRouteByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
routeID string
expectError bool
}{
{
name: "retrieve existing route",
routeID: "ct03t427qv97vmtmglog",
expectError: false,
},
{
name: "retrieve non-existing route",
routeID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty route ID",
routeID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, route)
} else {
require.NoError(t, err)
require.NotNil(t, route)
require.Equal(t, tt.routeID, string(route.ID))
}
})
}
}
func TestSqlStore_SaveRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
route := &route2.Route{
ID: "route-id",
AccountID: accountID,
Network: netip.MustParsePrefix("10.10.0.0/16"),
NetID: "netID",
PeerGroups: []string{"routeA"},
NetworkType: route2.IPv4Network,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"groupA"},
AccessControlGroups: []string{},
}
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
require.NoError(t, err)
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
require.NoError(t, err)
require.Equal(t, route, saveRoute)
}
func TestSqlStore_DeleteRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
routeID := "ct03t427qv97vmtmglog"
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
require.NoError(t, err)
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
require.Error(t, err)
require.Nil(t, route)
}
func TestSqlStore_GetAccountMeta(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@@ -145,7 +145,9 @@ type Store interface {
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)

View File

@@ -18,6 +18,10 @@ type UpdateChannelMetrics struct {
getAllConnectedPeersDurationMicro metric.Int64Histogram
getAllConnectedPeers metric.Int64Histogram
hasChannelDurationMicro metric.Int64Histogram
calcPostureChecksDurationMicro metric.Int64Histogram
calcPeerNetworkMapDurationMicro metric.Int64Histogram
mergeNetworkMapDurationMicro metric.Int64Histogram
toSyncResponseDurationMicro metric.Int64Histogram
ctx context.Context
}
@@ -89,6 +93,38 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
return nil, err
}
calcPostureChecksDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.posturechecks.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to get the posture checks for a peer"),
)
if err != nil {
return nil, err
}
calcPeerNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.networkmap.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to calculate the network map for a peer"),
)
if err != nil {
return nil, err
}
mergeNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.merge.networkmap.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to merge the network maps for a peer"),
)
if err != nil {
return nil, err
}
toSyncResponseDurationMicro, err := meter.Int64Histogram("management.updatechannel.tosyncresponse.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to convert the network map to sync response"),
)
if err != nil {
return nil, err
}
return &UpdateChannelMetrics{
createChannelDurationMicro: createChannelDurationMicro,
closeChannelDurationMicro: closeChannelDurationMicro,
@@ -98,6 +134,10 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro,
getAllConnectedPeers: getAllConnectedPeers,
hasChannelDurationMicro: hasChannelDurationMicro,
calcPostureChecksDurationMicro: calcPostureChecksDurationMicro,
calcPeerNetworkMapDurationMicro: calcPeerNetworkMapDurationMicro,
mergeNetworkMapDurationMicro: mergeNetworkMapDurationMicro,
toSyncResponseDurationMicro: toSyncResponseDurationMicro,
ctx: ctx,
}, nil
}
@@ -137,3 +177,19 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration
func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) {
metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountCalcPostureChecksDuration(duration time.Duration) {
metrics.calcPostureChecksDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountCalcPeerNetworkMapDuration(duration time.Duration) {
metrics.calcPeerNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountMergeNetworkMapDuration(duration time.Duration) {
metrics.mergeNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountToSyncResponseDuration(duration time.Duration) {
metrics.toSyncResponseDurationMicro.Record(metrics.ctx, duration.Microseconds())
}

View File

@@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
INSERT INTO installations VALUES(1,'');

View File

@@ -36,6 +36,9 @@ const (
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
firewallRuleMinPortRangesVer = "0.48.0"
)
type LookupMap map[string]struct{}
@@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap)
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
@@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
@@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
@@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
@@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
continue
}
for _, port := range rule.Ports {
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
for _, portRange := range rule.PortRanges {
pr := fr
pr.PortRange = portRange
rules = append(rules, &pr)
}
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
@@ -1590,3 +1584,45 @@ func (a *Account) AddAllGroup() error {
}
return nil
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
var expanded []*FirewallRule
if len(rule.Ports) > 0 {
for _, port := range rule.Ports {
fr := base
fr.Port = port
expanded = append(expanded, &fr)
}
return expanded
}
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
for _, portRange := range rule.PortRanges {
fr := base
if supportPortRanges {
fr.PortRange = portRange
} else {
// Peer doesn't support port ranges, only allow single-port ranges
if portRange.Start != portRange.End {
continue
}
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
}
expanded = append(expanded, &fr)
}
return expanded
}
// peerSupportsPortRanges checks if the peer version supports port ranges.
func peerSupportsPortRanges(peerVer string) bool {
if strings.Contains(peerVer, "dev") {
return true
}
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}

View File

@@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
}
// TODO: generate IPv6 rules for dynamic routes

View File

@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
inviterID := userID
if initiatorUser.IsServiceUser {
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID)
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
}
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
}
// GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
return nil, err
}
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
}
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
}
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.NewPermissionDeniedError()
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return err
}
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return err
}
@@ -404,7 +404,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewAdminPermissionError()
}
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
if err != nil {
return err
}
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
}
// GetAllPATs returns all PATs for a user
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewAdminPermissionError()
}
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID)
return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
}
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if !allowed {
return nil, status.NewPermissionDeniedError()
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var addUserEvents []func()
var usersToSave = make([]*types.User, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
}
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
@@ -695,7 +695,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists {
@@ -830,7 +830,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var user *types.User
if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
@@ -840,7 +840,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{}
switch {
case allowed:
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -933,7 +933,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -1003,7 +1003,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
}
@@ -1017,7 +1017,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
@@ -1081,12 +1081,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID)
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserInfo.ID)
if err != nil {
return fmt.Errorf("failed to get user to delete: %w", err)
}
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID)
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil {
return fmt.Errorf("failed to get user peers: %w", err)
}
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
// GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID)
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1257,7 +1257,7 @@ func validateUserInvite(invite *types.UserInfo) error {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
@@ -1274,7 +1274,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}

View File

@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId)
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)

View File

@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
}
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}
func NewManagerMock() Manager {

View File

@@ -21,6 +21,7 @@ var (
// Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the
// daemon version are deprecated
type Update struct {
httpAgent string
uiVersion *goversion.Version
daemonVersion *goversion.Version
latestAvailable *goversion.Version
@@ -34,7 +35,7 @@ type Update struct {
}
// NewUpdate instantiate Update and start to fetch the new version information
func NewUpdate() *Update {
func NewUpdate(httpAgent string) *Update {
currentVersion, err := goversion.NewVersion(version)
if err != nil {
currentVersion, _ = goversion.NewVersion("0.0.0")
@@ -43,6 +44,7 @@ func NewUpdate() *Update {
latestAvailable, _ := goversion.NewVersion("0.0.0")
u := &Update{
httpAgent: httpAgent,
latestAvailable: latestAvailable,
uiVersion: currentVersion,
fetchTicker: time.NewTicker(fetchPeriod),
@@ -112,7 +114,15 @@ func (u *Update) startFetcher() {
func (u *Update) fetchVersion() bool {
log.Debugf("fetching version info from %s", versionURL)
resp, err := http.Get(versionURL)
req, err := http.NewRequest("GET", versionURL, nil)
if err != nil {
log.Errorf("failed to create request for version info: %s", err)
return false
}
req.Header.Set("User-Agent", u.httpAgent)
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Errorf("failed to fetch version info: %s", err)
return false

View File

@@ -9,6 +9,8 @@ import (
"time"
)
const httpAgent = "pkg/test"
func TestNewUpdate(t *testing.T) {
version = "1.0.0"
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -21,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
u := NewUpdate()
u := NewUpdate(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
@@ -46,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
u := NewUpdate()
u := NewUpdate(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
@@ -71,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
u := NewUpdate()
u := NewUpdate(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true