mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 08:02:04 -04:00
Compare commits
3 Commits
set-cmd
...
stop-using
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f5d597c59 | ||
|
|
b85aad07d4 | ||
|
|
7398836c2e |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
SIGN_PIPE_VER: "v0.0.19"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -134,7 +134,6 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -181,7 +180,6 @@ jobs:
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// SharedFlags contains all configuration flags that are common between up and set commands
|
||||
type SharedFlags struct {
|
||||
// Network configuration
|
||||
InterfaceName string
|
||||
WireguardPort uint16
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress string
|
||||
ExtraIFaceBlackList []string
|
||||
DNSLabels []string
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
// Feature flags
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
AutoConnectDisabled bool
|
||||
NetworkMonitor bool
|
||||
LazyConnEnabled bool
|
||||
|
||||
// System flags
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
// Login-specific (only for up command)
|
||||
NoBrowser bool
|
||||
}
|
||||
|
||||
// AddSharedFlags adds all shared configuration flags to the given command
|
||||
func AddSharedFlags(cmd *cobra.Command, flags *SharedFlags) {
|
||||
// Network configuration flags
|
||||
cmd.PersistentFlags().StringVar(&flags.InterfaceName, interfaceNameFlag, iface.WgInterfaceDefault,
|
||||
"Wireguard interface name")
|
||||
cmd.PersistentFlags().Uint16Var(&flags.WireguardPort, wireguardPortFlag, iface.DefaultWgPort,
|
||||
"Wireguard interface listening port")
|
||||
cmd.PersistentFlags().StringSliceVar(&flags.NATExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces. `+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
||||
`or --external-ip-map ""`)
|
||||
cmd.PersistentFlags().StringVar(&flags.CustomDNSAddress, dnsResolverAddress, "",
|
||||
`Sets a custom address for NetBird's local DNS resolver. `+
|
||||
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`)
|
||||
cmd.PersistentFlags().StringSliceVar(&flags.ExtraIFaceBlackList, extraIFaceBlackListFlag, nil,
|
||||
"Extra list of default interfaces to ignore for listening")
|
||||
cmd.PersistentFlags().StringSliceVar(&flags.DNSLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels. `+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`)
|
||||
cmd.PersistentFlags().DurationVar(&flags.DNSRouteInterval, dnsRouteIntervalFlag, time.Minute,
|
||||
"DNS route update interval")
|
||||
|
||||
// Feature flags
|
||||
cmd.PersistentFlags().BoolVar(&flags.RosenpassEnabled, enableRosenpassFlag, false,
|
||||
"[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.RosenpassPermissive, rosenpassPermissiveFlag, false,
|
||||
"[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.ServerSSHAllowed, serverSSHAllowedFlag, false,
|
||||
"Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
cmd.PersistentFlags().BoolVar(&flags.AutoConnectDisabled, disableAutoConnectFlag, false,
|
||||
"Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
cmd.PersistentFlags().BoolVarP(&flags.NetworkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`)
|
||||
cmd.PersistentFlags().BoolVar(&flags.LazyConnEnabled, enableLazyConnectionFlag, false,
|
||||
"[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
// System flags (from system.go)
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableClientRoutes, disableClientRoutesFlag, false,
|
||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableServerRoutes, disableServerRoutesFlag, false,
|
||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableDNS, disableDNSFlag, false,
|
||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.DisableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
cmd.PersistentFlags().BoolVar(&flags.BlockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
cmd.PersistentFlags().BoolVar(&flags.BlockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
// AddUpOnlyFlags adds flags that are specific to the up command
|
||||
func AddUpOnlyFlags(cmd *cobra.Command, flags *SharedFlags) {
|
||||
cmd.PersistentFlags().BoolVar(&flags.NoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
}
|
||||
|
||||
// BuildConfigInput creates an internal.ConfigInput from SharedFlags with Changed() checks
|
||||
func BuildConfigInput(cmd *cobra.Command, flags *SharedFlags, customDNSAddressConverted []byte) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
}
|
||||
|
||||
// Handle PreSharedKey from root command
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &flags.RosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &flags.RosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &flags.ServerSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(flags.InterfaceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &flags.InterfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(flags.WireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &flags.NetworkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &flags.AutoConnectDisabled
|
||||
|
||||
if flags.AutoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
} else {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &flags.DNSRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &flags.DisableClientRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &flags.DisableServerRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &flags.DisableDNS
|
||||
}
|
||||
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &flags.DisableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &flags.BlockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &flags.BlockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &flags.LazyConnEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(externalIPMapFlag).Changed {
|
||||
ic.NATExternalIPs = flags.NATExternalIPs
|
||||
}
|
||||
|
||||
if cmd.Flag(extraIFaceBlackListFlag).Changed {
|
||||
ic.ExtraIFaceBlackList = flags.ExtraIFaceBlackList
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
var err error
|
||||
ic.DNSLabels, err = domain.FromStringList(flags.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid DNS labels: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &ic, nil
|
||||
}
|
||||
@@ -149,7 +149,6 @@ func init() {
|
||||
rootCmd.AddCommand(loginCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(setCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
@@ -168,6 +167,24 @@ func init() {
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
||||
`or --external-ip-map ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringVar(&customDNSAddress, dnsResolverAddress, "",
|
||||
`Sets a custom address for NetBird's local DNS resolver. `+
|
||||
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
setFlags = &SharedFlags{}
|
||||
|
||||
setCmd = &cobra.Command{
|
||||
Use: "set",
|
||||
Short: "Update NetBird client configuration",
|
||||
Long: `Update NetBird client configuration without connecting. Uses the same flags as 'netbird up' but only updates the configuration file.`,
|
||||
RunE: setFunc,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add all shared flags to the set command
|
||||
AddSharedFlags(setCmd, setFlags)
|
||||
// Note: We don't add up-only flags like --no-browser to set command
|
||||
}
|
||||
|
||||
func setFunc(cmd *cobra.Command, _ []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
// Validate inputs (reuse validation logic from up.go)
|
||||
if err := validateNATExternalIPs(setFlags.NATExternalIPs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
if _, err := validateDnsLabels(setFlags.DNSLabels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var customDNSAddressConverted []byte
|
||||
if cmd.Flag(dnsResolverAddress).Changed {
|
||||
var err error
|
||||
customDNSAddressConverted, err = parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to daemon
|
||||
ctx := cmd.Context()
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
fmt.Printf("Warning: failed to close connection: %v\n", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
req := &proto.SetConfigRequest{}
|
||||
|
||||
// Set fields based on changed flags
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
req.RosenpassEnabled = &setFlags.RosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
req.RosenpassPermissive = &setFlags.RosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &setFlags.ServerSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
req.DisableAutoConnect = &setFlags.AutoConnectDisabled
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
req.NetworkMonitor = &setFlags.NetworkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(setFlags.InterfaceName); err != nil {
|
||||
return err
|
||||
}
|
||||
req.InterfaceName = &setFlags.InterfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
port := int64(setFlags.WireguardPort)
|
||||
req.WireguardPort = &port
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsResolverAddress).Changed {
|
||||
customAddr := string(customDNSAddressConverted)
|
||||
req.CustomDNSAddress = &customAddr
|
||||
}
|
||||
|
||||
if cmd.Flag(extraIFaceBlackListFlag).Changed {
|
||||
req.ExtraIFaceBlacklist = setFlags.ExtraIFaceBlackList
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
req.DnsLabels = setFlags.DNSLabels
|
||||
req.CleanDNSLabels = &[]bool{setFlags.DNSLabels != nil && len(setFlags.DNSLabels) == 0}[0]
|
||||
}
|
||||
|
||||
if cmd.Flag(externalIPMapFlag).Changed {
|
||||
req.NatExternalIPs = setFlags.NATExternalIPs
|
||||
req.CleanNATExternalIPs = &[]bool{setFlags.NATExternalIPs != nil && len(setFlags.NATExternalIPs) == 0}[0]
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
req.DnsRouteInterval = durationpb.New(setFlags.DNSRouteInterval)
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
req.DisableClientRoutes = &setFlags.DisableClientRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
req.DisableServerRoutes = &setFlags.DisableServerRoutes
|
||||
}
|
||||
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
req.DisableDns = &setFlags.DisableDNS
|
||||
}
|
||||
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
req.DisableFirewall = &setFlags.DisableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
req.BlockLanAccess = &setFlags.BlockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
req.BlockInbound = &setFlags.BlockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &setFlags.LazyConnEnabled
|
||||
}
|
||||
|
||||
// Send the request
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
return fmt.Errorf("update configuration: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Configuration updated successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestParseBoolArg(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
hasError bool
|
||||
}{
|
||||
{"true", true, false},
|
||||
{"True", true, false},
|
||||
{"1", true, false},
|
||||
{"yes", true, false},
|
||||
{"on", true, false},
|
||||
{"false", false, false},
|
||||
{"False", false, false},
|
||||
{"0", false, false},
|
||||
{"no", false, false},
|
||||
{"off", false, false},
|
||||
{"invalid", false, true},
|
||||
{"maybe", false, true},
|
||||
{"", false, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
result, err := parseBoolArg(test.input)
|
||||
|
||||
if test.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for input %q, but got none", test.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for input %q: %v", test.input, err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("For input %q, expected %v but got %v", test.input, test.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommandStructure(t *testing.T) {
|
||||
// Test that the set command has the expected subcommands
|
||||
expectedSubcommands := []string{
|
||||
"autoconnect",
|
||||
"ssh-server",
|
||||
"network-monitor",
|
||||
"rosenpass",
|
||||
"dns",
|
||||
"dns-interval",
|
||||
}
|
||||
|
||||
actualSubcommands := make([]string, 0, len(setCmd.Commands()))
|
||||
for _, cmd := range setCmd.Commands() {
|
||||
actualSubcommands = append(actualSubcommands, cmd.Name())
|
||||
}
|
||||
|
||||
if len(actualSubcommands) != len(expectedSubcommands) {
|
||||
t.Errorf("Expected %d subcommands, got %d", len(expectedSubcommands), len(actualSubcommands))
|
||||
}
|
||||
|
||||
for _, expected := range expectedSubcommands {
|
||||
found := false
|
||||
for _, actual := range actualSubcommands {
|
||||
if actual == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected subcommand %q not found", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommandUsage(t *testing.T) {
|
||||
if setCmd.Use != "set" {
|
||||
t.Errorf("Expected command use to be 'set', got %q", setCmd.Use)
|
||||
}
|
||||
|
||||
if setCmd.Short != "Set NetBird client configuration" {
|
||||
t.Errorf("Expected short description to be 'Set NetBird client configuration', got %q", setCmd.Short)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubcommandArgRequirements(t *testing.T) {
|
||||
// Test that all subcommands except dns-interval require exactly 1 argument
|
||||
subcommands := []*cobra.Command{
|
||||
setAutoconnectCmd,
|
||||
setSSHServerCmd,
|
||||
setNetworkMonitorCmd,
|
||||
setRosenpassCmd,
|
||||
setDNSCmd,
|
||||
setDNSIntervalCmd,
|
||||
}
|
||||
|
||||
for _, cmd := range subcommands {
|
||||
if cmd.Args == nil {
|
||||
t.Errorf("Command %q should have Args validation", cmd.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -19,3 +19,24 @@ var (
|
||||
blockInbound bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add system flags to upCmd
|
||||
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
124
client/cmd/up.go
124
client/cmd/up.go
@@ -7,6 +7,7 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -40,7 +42,6 @@ var (
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
upFlags = &SharedFlags{}
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -50,12 +51,26 @@ var (
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add shared flags to up command
|
||||
AddSharedFlags(upCmd, upFlags)
|
||||
|
||||
// Add up-specific flags
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
AddUpOnlyFlags(upCmd, upFlags)
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -103,16 +118,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle DNS labels validation and assignment to SharedFlags
|
||||
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||
var err error
|
||||
dnsLabelsValidated, err = validateDnsLabels(upFlags.DNSLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ic, err := BuildConfigInput(cmd, upFlags, customDNSAddressConverted)
|
||||
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup config: %v", err)
|
||||
}
|
||||
@@ -229,6 +235,92 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: configPath,
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
ic.RosenpassEnabled = &rosenpassEnabled
|
||||
}
|
||||
|
||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||
ic.RosenpassPermissive = &rosenpassPermissive
|
||||
}
|
||||
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
|
||||
if cmd.Flag(wireguardPortFlag).Changed {
|
||||
p := int(wireguardPort)
|
||||
ic.WireguardPort = &p
|
||||
}
|
||||
|
||||
if cmd.Flag(networkMonitorFlag).Changed {
|
||||
ic.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
ic.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
ic.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
if autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||
}
|
||||
|
||||
if !autoConnectDisabled {
|
||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||
ic.DisableClientRoutes = &disableClientRoutes
|
||||
}
|
||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||
ic.DisableServerRoutes = &disableServerRoutes
|
||||
}
|
||||
if cmd.Flag(disableDNSFlag).Changed {
|
||||
ic.DisableDNS = &disableDNS
|
||||
}
|
||||
if cmd.Flag(disableFirewallFlag).Changed {
|
||||
ic.DisableFirewall = &disableFirewall
|
||||
}
|
||||
|
||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &blockInbound
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
|
||||
@@ -319,6 +319,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
*input.WireguardPort, config.WgPort)
|
||||
config.WgPort = *input.WireguardPort
|
||||
updated = true
|
||||
} else if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -525,13 +526,17 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
|
||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||
func freePort(initPort int) (int, error) {
|
||||
addr := net.UDPAddr{Port: initPort}
|
||||
addr := net.UDPAddr{}
|
||||
if initPort == 0 {
|
||||
initPort = iface.DefaultWgPort
|
||||
}
|
||||
|
||||
addr.Port = initPort
|
||||
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err == nil {
|
||||
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||
closeConnWithLog(conn)
|
||||
return returnPort, nil
|
||||
return initPort, nil
|
||||
}
|
||||
|
||||
// if the port is already in use, ask the system for a free port
|
||||
|
||||
@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "when port is 0 use random port",
|
||||
name: "not provided, fallback to default",
|
||||
port: 0,
|
||||
want: 0,
|
||||
shouldMatch: false,
|
||||
want: 51820,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "provided and available",
|
||||
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
@@ -39,14 +39,6 @@ func Test_freePort(t *testing.T) {
|
||||
_ = c1.Close()
|
||||
}(c1)
|
||||
|
||||
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
|
||||
tests[1].port++
|
||||
tests[1].want++
|
||||
}
|
||||
|
||||
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -1476,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -28,10 +28,7 @@ func (n Nexthop) String() string {
|
||||
if n.Intf == nil {
|
||||
return n.IP.String()
|
||||
}
|
||||
if n.IP.IsValid() {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -67,9 +67,6 @@ service DaemonService {
|
||||
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
||||
|
||||
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
||||
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
rpc SetConfig(SetConfigRequest) returns (SetConfigResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -161,7 +158,6 @@ message UpResponse {}
|
||||
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
@@ -499,29 +495,3 @@ message GetEventsRequest {}
|
||||
message GetEventsResponse {
|
||||
repeated SystemEvent events = 1;
|
||||
}
|
||||
|
||||
message SetConfigRequest {
|
||||
optional bool rosenpassEnabled = 1;
|
||||
optional bool rosenpassPermissive = 2;
|
||||
optional bool serverSSHAllowed = 3;
|
||||
optional bool disableAutoConnect = 4;
|
||||
optional bool networkMonitor = 5;
|
||||
optional google.protobuf.Duration dnsRouteInterval = 6;
|
||||
optional bool disable_client_routes = 7;
|
||||
optional bool disable_server_routes = 8;
|
||||
optional bool disable_dns = 9;
|
||||
optional bool disable_firewall = 10;
|
||||
optional bool block_lan_access = 11;
|
||||
optional bool lazyConnectionEnabled = 12;
|
||||
optional bool block_inbound = 13;
|
||||
optional string interfaceName = 14;
|
||||
optional int64 wireguardPort = 15;
|
||||
optional string customDNSAddress = 16;
|
||||
repeated string extraIFaceBlacklist = 17;
|
||||
repeated string dns_labels = 18;
|
||||
optional bool cleanDNSLabels = 19;
|
||||
repeated string natExternalIPs = 20;
|
||||
optional bool cleanNATExternalIPs = 21;
|
||||
}
|
||||
|
||||
message SetConfigResponse {}
|
||||
|
||||
@@ -55,8 +55,6 @@ type DaemonServiceClient interface {
|
||||
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
|
||||
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -270,15 +268,6 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error) {
|
||||
out := new(SetConfigResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetConfig", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -320,8 +309,6 @@ type DaemonServiceServer interface {
|
||||
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
|
||||
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -389,9 +376,6 @@ func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, Daemo
|
||||
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SetConfig not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -768,24 +752,6 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SetConfig_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SetConfigRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).SetConfig(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SetConfig",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SetConfig(ctx, req.(*SetConfigRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -869,10 +835,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetEvents",
|
||||
Handler: _DaemonService_GetEvents_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SetConfig",
|
||||
Handler: _DaemonService_SetConfig_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -707,9 +707,7 @@ func (s *Server) Status(
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
if msg.GetFullPeerStatus {
|
||||
if msg.ShouldRunProbes {
|
||||
s.runProbes()
|
||||
}
|
||||
s.runProbes()
|
||||
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
@@ -799,133 +797,6 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetConfig updates daemon configuration without reconnecting
|
||||
func (s *Server) SetConfig(ctx context.Context, req *proto.SetConfigRequest) (*proto.SetConfigResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "daemon is not configured")
|
||||
}
|
||||
|
||||
configChanged := false
|
||||
|
||||
if req.RosenpassEnabled != nil && s.config.RosenpassEnabled != *req.RosenpassEnabled {
|
||||
s.config.RosenpassEnabled = *req.RosenpassEnabled
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.RosenpassPermissive != nil && s.config.RosenpassPermissive != *req.RosenpassPermissive {
|
||||
s.config.RosenpassPermissive = *req.RosenpassPermissive
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.ServerSSHAllowed != nil && s.config.ServerSSHAllowed != nil && *s.config.ServerSSHAllowed != *req.ServerSSHAllowed {
|
||||
*s.config.ServerSSHAllowed = *req.ServerSSHAllowed
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableAutoConnect != nil && s.config.DisableAutoConnect != *req.DisableAutoConnect {
|
||||
s.config.DisableAutoConnect = *req.DisableAutoConnect
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.NetworkMonitor != nil && s.config.NetworkMonitor != nil && *s.config.NetworkMonitor != *req.NetworkMonitor {
|
||||
*s.config.NetworkMonitor = *req.NetworkMonitor
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DnsRouteInterval != nil {
|
||||
duration := req.DnsRouteInterval.AsDuration()
|
||||
if s.config.DNSRouteInterval != duration {
|
||||
s.config.DNSRouteInterval = duration
|
||||
configChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if req.DisableClientRoutes != nil && s.config.DisableClientRoutes != *req.DisableClientRoutes {
|
||||
s.config.DisableClientRoutes = *req.DisableClientRoutes
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableServerRoutes != nil && s.config.DisableServerRoutes != *req.DisableServerRoutes {
|
||||
s.config.DisableServerRoutes = *req.DisableServerRoutes
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableDns != nil && s.config.DisableDNS != *req.DisableDns {
|
||||
s.config.DisableDNS = *req.DisableDns
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.DisableFirewall != nil && s.config.DisableFirewall != *req.DisableFirewall {
|
||||
s.config.DisableFirewall = *req.DisableFirewall
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.BlockLanAccess != nil && s.config.BlockLANAccess != *req.BlockLanAccess {
|
||||
s.config.BlockLANAccess = *req.BlockLanAccess
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.LazyConnectionEnabled != nil && s.config.LazyConnectionEnabled != *req.LazyConnectionEnabled {
|
||||
s.config.LazyConnectionEnabled = *req.LazyConnectionEnabled
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.BlockInbound != nil && s.config.BlockInbound != *req.BlockInbound {
|
||||
s.config.BlockInbound = *req.BlockInbound
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.InterfaceName != nil && s.config.WgIface != *req.InterfaceName {
|
||||
s.config.WgIface = *req.InterfaceName
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.WireguardPort != nil && s.config.WgPort != int(*req.WireguardPort) {
|
||||
s.config.WgPort = int(*req.WireguardPort)
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if req.CustomDNSAddress != nil && s.config.CustomDNSAddress != *req.CustomDNSAddress {
|
||||
s.config.CustomDNSAddress = *req.CustomDNSAddress
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if len(req.ExtraIFaceBlacklist) > 0 {
|
||||
s.config.IFaceBlackList = req.ExtraIFaceBlacklist
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if len(req.DnsLabels) > 0 || (req.CleanDNSLabels != nil && *req.CleanDNSLabels) {
|
||||
if req.CleanDNSLabels != nil && *req.CleanDNSLabels {
|
||||
s.config.DNSLabels = domain.List{}
|
||||
} else {
|
||||
var err error
|
||||
s.config.DNSLabels, err = domain.FromStringList(req.DnsLabels)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid DNS labels: %v", err)
|
||||
}
|
||||
}
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if len(req.NatExternalIPs) > 0 || (req.CleanNATExternalIPs != nil && *req.CleanNATExternalIPs) {
|
||||
s.config.NATExternalIPs = req.NatExternalIPs
|
||||
configChanged = true
|
||||
}
|
||||
|
||||
if configChanged {
|
||||
if err := internal.WriteOutConfig(s.latestConfigInput.ConfigPath, s.config); err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "write config: %v", err)
|
||||
}
|
||||
log.Debug("Configuration updated successfully")
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) onSessionExpire() {
|
||||
if runtime.GOOS != "windows" {
|
||||
isUIActive := internal.CheckUIApp()
|
||||
|
||||
@@ -206,7 +206,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
|
||||
func (s *serviceClient) onSessionExpire() {
|
||||
s.sendNotification = true
|
||||
if s.sendNotification {
|
||||
go s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
s.sendNotification = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAI
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
@@ -140,4 +139,3 @@ export NETBIRD_RELAY_PORT
|
||||
export NETBIRD_RELAY_ENDPOINT
|
||||
export NETBIRD_RELAY_AUTH_SECRET
|
||||
export NETBIRD_RELAY_TAG
|
||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -791,6 +791,7 @@ services:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
- '8080:8080'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
@@ -38,7 +38,6 @@
|
||||
"0.0.0.0/0"
|
||||
]
|
||||
},
|
||||
"DisableDefaultPolicy": $NETBIRD_MGMT_DISABLE_DEFAULT_POLICY,
|
||||
"Datadir": "",
|
||||
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
|
||||
"StoreConfig": {
|
||||
|
||||
@@ -92,8 +92,7 @@ NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
# Disable default all-to-all policy for new accounts
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||
|
||||
# -------------------------------------------
|
||||
# Relay settings
|
||||
# -------------------------------------------
|
||||
|
||||
@@ -29,4 +29,3 @@ NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
||||
NETBIRD_RELAY_PORT=33445
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -100,7 +100,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ var (
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
proxyController := integrations.NewController(store)
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -102,8 +102,6 @@ type DefaultAccountManager struct {
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
@@ -172,7 +170,6 @@ func BuildManager(
|
||||
proxyController port_forwarding.Controller,
|
||||
settingsManager settings.Manager,
|
||||
permissionsManager permissions.Manager,
|
||||
disableDefaultPolicy bool,
|
||||
) (*DefaultAccountManager, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -198,7 +195,6 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -374,7 +370,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
|
||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -547,7 +543,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
@@ -712,7 +708,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
|
||||
// AccountExists checks if an account exists.
|
||||
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
|
||||
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
|
||||
@@ -724,7 +720,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||
}
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
@@ -779,7 +775,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||
accountIDString := fmt.Sprintf("%v", accountID)
|
||||
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -833,7 +829,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -863,7 +859,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
|
||||
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
||||
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
||||
return nil, err
|
||||
@@ -1014,7 +1010,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID)
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return err
|
||||
@@ -1024,7 +1020,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||
return err
|
||||
@@ -1189,7 +1185,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
@@ -1209,7 +1205,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
|
||||
@@ -1241,7 +1237,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1267,12 +1263,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
var hasChanges bool
|
||||
var user *types.User
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -1302,7 +1298,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -1312,7 +1308,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
@@ -1344,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||
} else {
|
||||
@@ -1357,7 +1353,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||
} else {
|
||||
@@ -1418,7 +1414,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
}
|
||||
|
||||
if userAuth.IsChild {
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil || !exists {
|
||||
return "", err
|
||||
}
|
||||
@@ -1442,7 +1438,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
@@ -1463,7 +1459,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
@@ -1478,7 +1474,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
|
||||
// check again if the domain has a primary account because of simultaneous requests
|
||||
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
cancel()
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
@@ -1489,7 +1485,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
@@ -1499,7 +1495,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
|
||||
}
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return "", err
|
||||
@@ -1510,7 +1506,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", err
|
||||
@@ -1640,7 +1636,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -1662,7 +1658,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
|
||||
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
||||
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||
}
|
||||
@@ -1688,11 +1684,11 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
@@ -1735,7 +1731,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -1774,7 +1770,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
defer cancel()
|
||||
|
||||
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
@@ -1794,7 +1790,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
for range 2 {
|
||||
accountId := xid.New().String()
|
||||
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
|
||||
if err != nil || exists {
|
||||
continue
|
||||
}
|
||||
@@ -1837,7 +1833,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
},
|
||||
}
|
||||
|
||||
if err := newAccount.AddAllGroup(am.disableDefaultPolicy); err != nil {
|
||||
if err := newAccount.AddAllGroup(); err != nil {
|
||||
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
}
|
||||
|
||||
@@ -1869,7 +1865,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
return nil
|
||||
}
|
||||
|
||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
|
||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
|
||||
|
||||
// error is not a not found error
|
||||
if handleNotFound(err) != nil {
|
||||
@@ -1906,7 +1902,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||
// Returns true if any groups were modified, true if those updates affect peers and an error.
|
||||
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
@@ -1916,7 +1912,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
@@ -1924,7 +1920,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
|
||||
groupsToUpdate := make(map[string]*types.Group)
|
||||
|
||||
for _, user := range users {
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
@@ -373,7 +373,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -398,7 +398,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -640,7 +640,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||
|
||||
@@ -793,7 +793,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -899,11 +899,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||
}
|
||||
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
}
|
||||
@@ -1775,7 +1775,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.NotNil(t, settings)
|
||||
@@ -1960,7 +1960,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
@@ -2643,7 +2643,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
||||
})
|
||||
@@ -2657,7 +2657,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
|
||||
})
|
||||
@@ -2671,11 +2671,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2692,11 +2692,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2710,7 +2710,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2724,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2738,11 +2738,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
})
|
||||
@@ -2756,7 +2756,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
||||
@@ -2771,7 +2771,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
||||
})
|
||||
@@ -2879,7 +2879,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3354,7 +3354,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
||||
@@ -3365,7 +3365,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
|
||||
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
@@ -3376,7 +3376,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group2.ID)
|
||||
@@ -3403,7 +3403,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.True(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
@@ -3420,7 +3420,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = []string{"group1"}
|
||||
@@ -3431,7 +3431,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
|
||||
@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return userAuth, err
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -217,7 +217,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -267,7 +267,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
|
||||
@@ -122,7 +122,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare)
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
|
||||
@@ -127,7 +127,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
||||
@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
||||
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||
if err != nil {
|
||||
// @todo consider logging
|
||||
continue
|
||||
|
||||
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@@ -140,7 +140,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
@@ -152,13 +152,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
}
|
||||
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
|
||||
return nil
|
||||
@@ -431,7 +431,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
return err
|
||||
@@ -448,7 +448,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
@@ -460,7 +460,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == types.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to get user")
|
||||
}
|
||||
@@ -506,7 +506,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to get DNS settings")
|
||||
}
|
||||
@@ -515,7 +515,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to get account settings")
|
||||
}
|
||||
@@ -529,7 +529,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -549,7 +549,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -567,7 +567,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -586,7 +586,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -602,7 +602,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -618,7 +618,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
|
||||
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -638,7 +638,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -666,7 +666,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
|
||||
@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
|
||||
return nil, fmt.Errorf("error adding resource to group: %w", err)
|
||||
}
|
||||
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting group: %w", err)
|
||||
}
|
||||
|
||||
// TODO: at some point, this will need to become a switch statement
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID)
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting network resource: %w", err)
|
||||
}
|
||||
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
|
||||
return nil, fmt.Errorf("error removing resource from group: %w", err)
|
||||
}
|
||||
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting group: %w", err)
|
||||
}
|
||||
|
||||
// TODO: at some point, this will need to become a switch statement
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting network resource: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
package testing_tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
@@ -137,7 +138,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
|
||||
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -83,17 +83,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
var peers []*nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -444,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
|
||||
if err != nil {
|
||||
cleanup()
|
||||
@@ -641,7 +641,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
@@ -211,7 +211,7 @@ func startServer(
|
||||
port_forwarding.NewControllerMock(),
|
||||
settingsMockManager,
|
||||
permissionsManager,
|
||||
false)
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
|
||||
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return err
|
||||
}
|
||||
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -779,7 +779,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -848,7 +848,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID)
|
||||
return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
|
||||
@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
|
||||
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
|
||||
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network resources: %w", err)
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
@@ -218,17 +218,17 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
|
||||
}
|
||||
|
||||
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
|
||||
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
|
||||
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil && oldResource.ID != resource.ID {
|
||||
return status.Errorf(status.InvalidArgument, "new resource name already exists")
|
||||
}
|
||||
|
||||
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
|
||||
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
|
||||
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network routers: %w", err)
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
var network *networkTypes.Network
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID)
|
||||
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
var network *networkTypes.Network
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ import (
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter)
|
||||
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return accountPeers, nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -216,7 +216,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -335,7 +335,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -468,7 +468,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
addedByUser := false
|
||||
if len(userID) > 0 {
|
||||
addedByUser = true
|
||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
} else {
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
||||
}
|
||||
@@ -488,7 +488,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
@@ -584,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||
}
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
@@ -674,7 +674,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
|
||||
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||
}
|
||||
@@ -706,7 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
var err error
|
||||
var postureChecks []*posture.Checks
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -718,7 +718,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
if peer.UserID != "" {
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -821,7 +821,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
var isPeerUpdated bool
|
||||
var postureChecks []*posture.Checks
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -906,7 +906,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -930,7 +930,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||
}
|
||||
|
||||
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs)
|
||||
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -945,7 +945,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
|
||||
continue
|
||||
}
|
||||
|
||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources)
|
||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -970,7 +970,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
|
||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -981,7 +981,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1000,7 +1000,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -1062,7 +1062,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
|
||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
@@ -1104,7 +1104,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1117,7 +1117,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1143,7 +1143,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
|
||||
// it is also possible that user doesn't own the peer but some of his peers have access to it,
|
||||
// this is a valid case, show the peer as well.
|
||||
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
|
||||
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1328,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are connected.
|
||||
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
||||
return peerSchedulerRetryInterval, true
|
||||
@@ -1338,7 +1338,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
|
||||
return 0, false
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return peerSchedulerRetryInterval, true
|
||||
@@ -1372,7 +1372,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
||||
return peerSchedulerRetryInterval, true
|
||||
@@ -1382,7 +1382,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
|
||||
return 0, false
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return peerSchedulerRetryInterval, true
|
||||
@@ -1413,12 +1413,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
|
||||
|
||||
// getExpiredPeers returns peers that have been expired.
|
||||
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1436,12 +1436,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
||||
|
||||
// getInactivePeers returns peers that have been expired by inactivity
|
||||
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1459,12 +1459,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
|
||||
|
||||
// GetPeerGroups returns groups that the peer is part of.
|
||||
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
|
||||
return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
return am.Store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
||||
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
|
||||
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1478,7 +1478,7 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
||||
}
|
||||
|
||||
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
|
||||
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
||||
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1505,7 +1505,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
|
||||
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
|
||||
var peerDeletedEvents []func()
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1516,7 +1516,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
return nil, err
|
||||
}
|
||||
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1577,7 +1577,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
|
||||
|
||||
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
|
||||
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
|
||||
return false, nil
|
||||
|
||||
@@ -31,7 +31,7 @@ type Peer struct {
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
// The user ID that registered the peer
|
||||
UserID string
|
||||
UserID string `gorm:"index"`
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
|
||||
@@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -667,7 +667,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@@ -737,7 +737,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1267,7 +1267,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1301,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
|
||||
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.UserID, existingUserID)
|
||||
@@ -1342,7 +1342,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1423,7 +1423,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
|
||||
|
||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key)
|
||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
|
||||
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
|
||||
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
|
||||
@@ -1477,7 +1477,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1505,7 +1505,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
|
||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
|
||||
require.Error(t, err)
|
||||
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
@@ -1546,7 +1546,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1671,7 +1671,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
|
||||
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
|
||||
|
||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey)
|
||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
|
||||
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
|
||||
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
||||
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")
|
||||
@@ -2052,7 +2052,7 @@ func Test_DeletePeer(t *testing.T) {
|
||||
// account with an admin and a regular user
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Peers = map[string]*nbpeer.Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
|
||||
@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
|
||||
return true, nil
|
||||
}
|
||||
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
|
||||
return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
if policy.ID != "" {
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
policy.AccountID = accountID
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
||||
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
var postureChecks *posture.Checks
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
|
||||
|
||||
// If the posture check already has an ID, verify its existence in the store.
|
||||
if postureChecks.ID != "" {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For new posture checks, ensure no duplicates by name.
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
|
||||
seenPeers[string(prefixRoute.ID)] = true
|
||||
}
|
||||
|
||||
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
|
||||
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
|
||||
|
||||
if peerID := checkRoute.Peer; peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
|
||||
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
|
||||
}
|
||||
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
|
||||
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
|
||||
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
|
||||
// validateRouteGroups validates the route groups and returns the validated groups map.
|
||||
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
|
||||
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
||||
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
|
||||
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id)
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
var groupHA1, groupHA2 *types.Group
|
||||
for _, group := range groups {
|
||||
@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
@@ -1305,7 +1305,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
|
||||
return nil, fmt.Errorf("get extra settings: %w", err)
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account settings: %w", err)
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
|
||||
return nil, fmt.Errorf("get extra settings: %w", err)
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account settings: %w", err)
|
||||
}
|
||||
|
||||
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
|
||||
}
|
||||
|
||||
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
|
||||
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyToSave.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
var deletedSetupKey *types.SetupKey
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
|
||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||
return nil
|
||||
|
||||
@@ -478,7 +478,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
|
||||
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -488,11 +488,7 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountID string
|
||||
result := tx.Model(&types.Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
@@ -542,11 +538,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var user types.User
|
||||
result := tx.
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
@@ -563,11 +555,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var user types.User
|
||||
result := tx.First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
@@ -600,11 +588,7 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var users []*types.User
|
||||
result := tx.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
@@ -619,11 +603,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var user types.User
|
||||
result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||
if result.Error != nil {
|
||||
@@ -637,11 +617,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
@@ -656,11 +632,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var groups []*types.Group
|
||||
|
||||
likePattern := `%"ID":"` + resourceID + `"%`
|
||||
@@ -706,11 +678,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountMeta types.AccountMeta
|
||||
result := tx.Model(&types.Account{}).
|
||||
First(&accountMeta, idQueryCondition, accountID)
|
||||
@@ -880,11 +848,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountID string
|
||||
result := tx.Model(&types.User{}).
|
||||
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
@@ -899,11 +863,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountID string
|
||||
result := tx.Model(&nbpeer.Peer{}).
|
||||
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
|
||||
@@ -936,11 +896,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var ipJSONStrings []string
|
||||
|
||||
// Fetch the IP addresses as JSON strings
|
||||
@@ -968,11 +924,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var labels []string
|
||||
result := tx.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
@@ -990,11 +942,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -1006,11 +954,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peer nbpeer.Peer
|
||||
result := tx.First(&peer, GetKeyQueryCondition(s), peerKey)
|
||||
|
||||
@@ -1025,11 +969,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountSettings types.AccountSettings
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -1041,11 +981,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var createdBy string
|
||||
result := tx.Model(&types.Account{}).
|
||||
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
||||
@@ -1184,7 +1120,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
|
||||
for _, account := range fileStore.GetAllAccounts(ctx) {
|
||||
_, err = account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -1243,11 +1179,7 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var setupKey types.SetupKey
|
||||
result := tx.
|
||||
First(&setupKey, GetKeyQueryCondition(s), key)
|
||||
@@ -1391,11 +1323,7 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string
|
||||
|
||||
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
|
||||
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var groups []*types.Group
|
||||
query := tx.
|
||||
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
|
||||
@@ -1409,8 +1337,9 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountPeers retrieves peers for an account.
|
||||
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peers []*nbpeer.Peer
|
||||
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID)
|
||||
query := tx.Where(accountIDCondition, accountID)
|
||||
|
||||
if nameFilter != "" {
|
||||
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
|
||||
@@ -1429,11 +1358,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
|
||||
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peers []*nbpeer.Peer
|
||||
|
||||
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
|
||||
@@ -1461,11 +1386,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStr
|
||||
|
||||
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peer *nbpeer.Peer
|
||||
result := tx.
|
||||
First(&peer, accountAndIDQueryCondition, accountID, peerID)
|
||||
@@ -1481,11 +1402,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
|
||||
|
||||
// GetPeersByIDs retrieves peers by their IDs and account ID.
|
||||
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
|
||||
if result.Error != nil {
|
||||
@@ -1503,11 +1420,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
@@ -1522,11 +1435,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
|
||||
|
||||
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
@@ -1541,11 +1450,7 @@ func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStreng
|
||||
|
||||
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
|
||||
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var allEphemeralPeers, batchPeers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("ephemeral = ?", true).
|
||||
@@ -1623,11 +1528,7 @@ func (s *SqlStore) GetDB() *gorm.DB {
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountDNSSettings types.AccountDNSSettings
|
||||
result := tx.Model(&types.Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
@@ -1643,11 +1544,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var accountID string
|
||||
result := tx.Model(&types.Account{}).
|
||||
Select("id").First(&accountID, idQueryCondition, id)
|
||||
@@ -1663,11 +1560,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var account types.Account
|
||||
result := tx.Model(&types.Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
@@ -1683,11 +1576,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
|
||||
// GetGroupByID retrieves a group by ID and account ID.
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var group *types.Group
|
||||
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -1703,11 +1592,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var group types.Group
|
||||
|
||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||
@@ -1736,11 +1621,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
||||
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
@@ -1796,11 +1677,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var policies []*types.Policy
|
||||
result := tx.
|
||||
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
||||
@@ -1814,11 +1691,7 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var policy *types.Policy
|
||||
|
||||
result := tx.Preload(clause.Associations).
|
||||
@@ -1880,11 +1753,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var postureChecks []*posture.Checks
|
||||
result := tx.Find(&postureChecks, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
@@ -1897,10 +1766,7 @@ func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength Loc
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
|
||||
var postureCheck *posture.Checks
|
||||
result := tx.
|
||||
@@ -1918,11 +1784,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
|
||||
|
||||
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var postureChecks []*posture.Checks
|
||||
result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
|
||||
if result.Error != nil {
|
||||
@@ -1967,9 +1829,9 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var routes []*route.Route
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&routes, accountIDCondition, accountID)
|
||||
result := tx.Find(&routes, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
||||
@@ -1980,9 +1842,9 @@ func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStr
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var route *route.Route
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||
result := tx.First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewRouteNotFoundError(routeID)
|
||||
@@ -2023,11 +1885,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var setupKeys []*types.SetupKey
|
||||
result := tx.
|
||||
Find(&setupKeys, accountIDCondition, accountID)
|
||||
@@ -2041,14 +1899,9 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var setupKey *types.SetupKey
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||
result := tx.First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
|
||||
@@ -2088,11 +1941,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var nsGroups []*nbdns.NameServerGroup
|
||||
result := tx.Find(&nsGroups, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -2105,11 +1954,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
result := tx.
|
||||
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||
@@ -2182,11 +2027,7 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var networks []*networkTypes.Network
|
||||
result := tx.Find(&networks, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
@@ -2198,11 +2039,7 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var network *networkTypes.Network
|
||||
result := tx.
|
||||
First(&network, accountAndIDQueryCondition, accountID, networkID)
|
||||
@@ -2244,11 +2081,7 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netRouters []*routerTypes.NetworkRouter
|
||||
result := tx.
|
||||
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
|
||||
@@ -2261,11 +2094,7 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netRouters []*routerTypes.NetworkRouter
|
||||
result := tx.
|
||||
Find(&netRouters, accountIDCondition, accountID)
|
||||
@@ -2278,11 +2107,7 @@ func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netRouter *routerTypes.NetworkRouter
|
||||
result := tx.
|
||||
First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
|
||||
@@ -2323,11 +2148,7 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netResources []*resourceTypes.NetworkResource
|
||||
result := tx.
|
||||
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
|
||||
@@ -2340,11 +2161,7 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netResources []*resourceTypes.NetworkResource
|
||||
result := tx.
|
||||
Find(&netResources, accountIDCondition, accountID)
|
||||
@@ -2357,11 +2174,7 @@ func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStren
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netResources *resourceTypes.NetworkResource
|
||||
result := tx.
|
||||
First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
|
||||
@@ -2377,11 +2190,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var netResources *resourceTypes.NetworkResource
|
||||
result := tx.
|
||||
First(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
|
||||
@@ -2423,10 +2232,7 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
|
||||
|
||||
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
|
||||
var pat types.PersonalAccessToken
|
||||
result := tx.First(&pat, "hashed_token = ?", hashedToken)
|
||||
@@ -2443,11 +2249,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var pat types.PersonalAccessToken
|
||||
result := tx.
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
@@ -2464,11 +2266,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
// GetUserPATs retrieves personal access tokens for a user.
|
||||
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
var pats []*types.PersonalAccessToken
|
||||
result := tx.Find(&pats, "user_id = ?", userID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -2528,11 +2326,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
tx := s.getTXWithLockStrength(lockStrength)
|
||||
jsonValue := fmt.Sprintf(`"%s"`, ip.String())
|
||||
|
||||
var peer nbpeer.Peer
|
||||
@@ -2559,3 +2353,11 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getTXWithLockStrength(lockStrength LockingStrength) *gorm.DB {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
@@ -2044,7 +2044,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(false); err != nil {
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
|
||||
@@ -391,7 +391,7 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
|
||||
|
||||
_, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
return err
|
||||
}
|
||||
shouldSave = true
|
||||
|
||||
@@ -19,7 +19,7 @@ type UpdateChannelMetrics struct {
|
||||
getAllConnectedPeers metric.Int64Histogram
|
||||
hasChannelDurationMicro metric.Int64Histogram
|
||||
calcPostureChecksDurationMicro metric.Int64Histogram
|
||||
calcPeerNetworkMapDurationMs metric.Int64Histogram
|
||||
calcPeerNetworkMapDurationMicro metric.Int64Histogram
|
||||
mergeNetworkMapDurationMicro metric.Int64Histogram
|
||||
toSyncResponseDurationMicro metric.Int64Histogram
|
||||
ctx context.Context
|
||||
@@ -101,8 +101,8 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
calcPeerNetworkMapDurationMs, err := meter.Int64Histogram("management.updatechannel.calc.networkmap.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
calcPeerNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.networkmap.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to calculate the network map for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -135,7 +135,7 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
getAllConnectedPeers: getAllConnectedPeers,
|
||||
hasChannelDurationMicro: hasChannelDurationMicro,
|
||||
calcPostureChecksDurationMicro: calcPostureChecksDurationMicro,
|
||||
calcPeerNetworkMapDurationMs: calcPeerNetworkMapDurationMs,
|
||||
calcPeerNetworkMapDurationMicro: calcPeerNetworkMapDurationMicro,
|
||||
mergeNetworkMapDurationMicro: mergeNetworkMapDurationMicro,
|
||||
toSyncResponseDurationMicro: toSyncResponseDurationMicro,
|
||||
ctx: ctx,
|
||||
@@ -183,7 +183,7 @@ func (metrics *UpdateChannelMetrics) CountCalcPostureChecksDuration(duration tim
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountCalcPeerNetworkMapDuration(duration time.Duration) {
|
||||
metrics.calcPeerNetworkMapDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
metrics.calcPeerNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountMergeNetworkMapDuration(duration time.Duration) {
|
||||
|
||||
@@ -1546,7 +1546,7 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
|
||||
}
|
||||
|
||||
// AddAllGroup to account object if it doesn't exist
|
||||
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
func (a *Account) AddAllGroup() error {
|
||||
if len(a.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
@@ -1558,10 +1558,6 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
}
|
||||
a.Groups = map[string]*Group{allGroup.ID: allGroup}
|
||||
|
||||
if disableDefaultPolicy {
|
||||
return nil
|
||||
}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &Policy{
|
||||
|
||||
@@ -53,9 +53,6 @@ type Config struct {
|
||||
StoreConfig StoreConfig
|
||||
|
||||
ReverseProxy ReverseProxy
|
||||
|
||||
// disable default all-to-all policy
|
||||
DisableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
|
||||
@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inviterID := userID
|
||||
if initiatorUser.IsServiceUser {
|
||||
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID)
|
||||
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided nbContext.UserAuths.
|
||||
// Expects account to have been created already.
|
||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
|
||||
// ListUsers returns lists of all users under the account.
|
||||
// It doesn't populate user information such as email or name.
|
||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
|
||||
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -404,7 +404,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
|
||||
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
|
||||
return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID)
|
||||
return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
|
||||
}
|
||||
|
||||
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
|
||||
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
var addUserEvents []func()
|
||||
var usersToSave = make([]*types.User, 0, len(updates))
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
|
||||
var initiatorUser *types.User
|
||||
if initiatorUserID != activity.SystemInitiator {
|
||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -695,7 +695,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
|
||||
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
||||
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
|
||||
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
|
||||
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
if !addIfNotExists {
|
||||
@@ -830,7 +830,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
|
||||
var user *types.User
|
||||
if initiatorUserID != activity.SystemInitiator {
|
||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
@@ -840,7 +840,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
accountUsers := []*types.User{}
|
||||
switch {
|
||||
case allowed:
|
||||
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -933,7 +933,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
|
||||
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1003,7 +1003,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1017,7 +1017,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
continue
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
@@ -1081,12 +1081,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID)
|
||||
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserInfo.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user to delete: %w", err)
|
||||
}
|
||||
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID)
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user peers: %w", err)
|
||||
}
|
||||
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
|
||||
// GetOwnerInfo retrieves the owner information for a given account ID.
|
||||
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
|
||||
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID)
|
||||
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1257,7 +1257,7 @@ func validateUserInvite(invite *types.UserInfo) error {
|
||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1274,7 +1274,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
assert.Equal(t, pat.ID, tokenID)
|
||||
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
@@ -131,7 +131,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
@@ -163,7 +163,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -188,7 +188,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -213,7 +213,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
@@ -256,7 +256,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -296,7 +296,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -406,7 +406,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -453,7 +453,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -501,7 +501,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -532,7 +532,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -639,7 +639,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -678,7 +678,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -705,7 +705,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -792,7 +792,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -952,7 +952,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -988,7 +988,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
|
||||
|
||||
@@ -1030,7 +1030,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
externalUser := &types.User{
|
||||
Id: "externalUser",
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1098,7 +1098,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1132,7 +1132,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1499,7 +1499,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
|
||||
targetId := "user2"
|
||||
account1.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
@@ -1508,7 +1508,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account2))
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
|
||||
assert.Error(t, err, "update user to another account should fail")
|
||||
|
||||
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId)
|
||||
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account1.Users[targetId].Id, user.Id)
|
||||
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)
|
||||
@@ -1535,7 +1535,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "")
|
||||
account1.Settings.RegularUsersViewBlocked = false
|
||||
account1.Users["blocked-user"] = &types.User{
|
||||
Id: "blocked-user",
|
||||
@@ -1557,7 +1557,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, store.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "")
|
||||
account2.Users["settings-blocked-user"] = &types.User{
|
||||
Id: "settings-blocked-user",
|
||||
Role: types.UserRoleUser,
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
|
||||
@@ -130,7 +130,7 @@ repo_gpgcheck=1
|
||||
EOF
|
||||
}
|
||||
|
||||
install_aur_package() {
|
||||
add_aur_repo() {
|
||||
INSTALL_PKGS="git base-devel go"
|
||||
REMOVE_PKGS=""
|
||||
|
||||
@@ -154,10 +154,8 @@ install_aur_package() {
|
||||
cd netbird-ui && makepkg -sri --noconfirm
|
||||
fi
|
||||
|
||||
if [ -n "$REMOVE_PKGS" ]; then
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
fi
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
}
|
||||
|
||||
prepare_tun_module() {
|
||||
@@ -279,9 +277,7 @@ install_netbird() {
|
||||
;;
|
||||
pacman)
|
||||
${SUDO} pacman -Syy
|
||||
install_aur_package
|
||||
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
|
||||
${SUDO} systemctl enable --now netbird@main.service
|
||||
add_aur_repo
|
||||
;;
|
||||
pkg)
|
||||
# Check if the package is already installed
|
||||
@@ -498,4 +494,4 @@ case "$UPDATE_FLAG" in
|
||||
;;
|
||||
*)
|
||||
install_netbird
|
||||
esac
|
||||
esac
|
||||
Reference in New Issue
Block a user