Compare commits

..

33 Commits

Author SHA1 Message Date
Misha Bragin
515ce9e3af Update management/server/sqlite_store.go 2024-04-17 20:55:32 +02:00
Misha Bragin
89383b7f01 Update management/server/sqlite_store.go 2024-04-17 20:55:01 +02:00
Misha Bragin
db34162733 Update management/server/sqlite_store.go 2024-04-17 20:54:14 +02:00
Misha Bragin
bd761e2177 Update management/server/sqlite_store.go 2024-04-17 20:53:32 +02:00
Misha Bragin
4e1b95a4c6 Update management/server/sqlite_store.go 2024-04-17 20:53:24 +02:00
Misha Bragin
05993af7bf Update management/server/sqlite_store.go 2024-04-17 20:53:11 +02:00
braginini
9d1cb00570 Fix setup keys test 2024-04-17 20:27:55 +02:00
braginini
543731df45 Fix setup keys test 2024-04-17 19:58:24 +02:00
braginini
e6628ec231 Fix setup keys 2024-04-17 19:48:09 +02:00
braginini
41d4dd2aff reduce log level of scheduler to trace 2024-04-17 19:34:59 +02:00
braginini
30bed57711 Fix account deletion 2024-04-17 19:12:53 +02:00
braginini
6960b68322 Add pats to test save account 2024-04-17 19:07:17 +02:00
braginini
3b3aa18148 Store setup keys and ns groups in a batch 2024-04-17 18:32:13 +02:00
braginini
93045f3e3a Fix rand lint issue 2024-04-17 18:07:02 +02:00
braginini
fd3c1dea8e Add save large account test 2024-04-17 18:02:10 +02:00
braginini
48aff7a26e Fix test compilation errors 2024-04-17 17:39:28 +02:00
braginini
83dfe8e3a3 Fix test compilation errors 2024-04-17 17:27:23 +02:00
braginini
38e10af2d9 Add accountID reference 2024-04-17 17:16:56 +02:00
braginini
99854a126a Add comments 2024-04-17 17:08:01 +02:00
braginini
a75f982fcd Copy account when storing to avoid reference issues 2024-04-17 17:03:21 +02:00
braginini
e7a6483912 Optimize all other objects storing in SQLite 2024-04-17 12:35:41 +02:00
braginini
30ede299b8 Optimize peer storing in SQLite 2024-04-17 11:50:33 +02:00
Viktor Liu
e3b76448f3 Fix ICE endpoint remote port in status command (#1851) 2024-04-16 14:01:59 +02:00
Viktor Liu
e0de86d6c9 Use fixed activity codes (#1846)
* Add duplicate constants check
2024-04-15 14:15:46 +02:00
Zoltan Papp
5204d07811 Pass integrated validator for API (#1814)
Pass integrated validator for API handler
2024-04-15 12:08:38 +02:00
Viktor Liu
5ea24ba56e Add sysctl opts to prevent reverse path filtering from dropping fwmark packets (#1839) 2024-04-12 17:53:07 +02:00
Viktor Liu
d30cf8706a Allow disabling custom routing (#1840) 2024-04-12 16:53:11 +02:00
Viktor Liu
15a2feb723 Use fixed preference for rules (#1836) 2024-04-12 16:07:03 +02:00
Viktor Liu
91b2f9fc51 Use route active store (#1834) 2024-04-12 15:22:40 +02:00
Carlos Hernandez
76702c8a09 Add safe read/write to route map (#1760) 2024-04-11 22:12:23 +02:00
Viktor Liu
061f673a4f Don't use the custom dialer as non-root (#1823) 2024-04-11 15:29:03 +02:00
Zoltan Papp
9505805313 Rename variable (#1829) 2024-04-11 14:08:03 +02:00
Maycon Santos
704c67dec8 Allow owners that did not create the account to delete it (#1825)
Sometimes the Owner role will be passed to new users, and they need to be able to delete the account
2024-04-11 10:02:51 +02:00
42 changed files with 648 additions and 233 deletions

View File

@@ -33,6 +33,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:

View File

@@ -794,6 +794,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
}
e.statusRecorder.ReplaceOfflinePeers(replacement)

View File

@@ -229,7 +229,6 @@ func (conn *Conn) reCreateAgent() error {
}
conn.agent, err = ice.NewAgent(agentConfig)
if err != nil {
return err
}
@@ -285,6 +284,7 @@ func (conn *Conn) Open() error {
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
ConnStatus: conn.status,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -344,6 +344,7 @@ func (conn *Conn) Open() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
@@ -465,9 +466,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
@@ -558,6 +560,7 @@ func (conn *Conn) cleanup() error {
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {

View File

@@ -14,6 +14,7 @@ import (
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
IP string
PubKey string
FQDN string
@@ -30,7 +31,38 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
Routes map[string]struct{}
routes map[string]struct{}
}
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -143,6 +175,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
PubKey: peerPubKey,
ConnStatus: StatusDisconnected,
FQDN: fqdn,
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
return nil
@@ -189,8 +222,8 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.Routes != nil {
peerState.Routes = receivedState.Routes
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState, peerState)
@@ -440,7 +473,6 @@ func (d *Status) IsLoginRequired() bool {
s, ok := gstatus.FromError(d.managementError)
if ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return true
}
return false
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"errors"
"testing"
"sync"
"github.com/stretchr/testify/assert"
)
@@ -42,6 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -62,6 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -80,6 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -104,6 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -196,7 +196,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
return fmt.Errorf("get peer state: %v", err)
}
delete(state.Routes, c.network.String())
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
@@ -268,10 +268,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[c.network.String()] = struct{}{}
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
}
m.ctx = nil
}
@@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
}
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported

View File

@@ -4,14 +4,14 @@ package routemanager
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -32,19 +32,31 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
tableID int
family int
priority int
invert bool
suppressPrefix int
description string
@@ -52,10 +64,10 @@ type ruleParams struct {
func getSetupRules() []ruleParams {
return []ruleParams{
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}
}
@@ -69,8 +81,6 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
log.Infof("Using legacy routing setup")
@@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
@@ -123,11 +140,17 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
}
@@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
@@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error {
}
func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
_, err := setSysctl(ipv4ForwardingPath, 1, false)
return err
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
@@ -429,7 +442,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -446,43 +459,13 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
func removeAllRules(params ruleParams) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
@@ -509,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
}
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil)
_, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())

View File

@@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix,
)

View File

@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
}
// Set the fwmark on the socket.
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}

View File

@@ -718,7 +718,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Routes: maps.Keys(peerState.Routes),
Routes: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)

2
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible

4
go.sum
View File

@@ -383,8 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=

View File

@@ -10,8 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
type wgKernelConfigurer struct {
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil {
return err
}
fwmark := nbnet.NetbirdFwmark
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,

View File

@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -251,7 +251,7 @@ var (
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

View File

@@ -242,19 +242,19 @@ type UserPermissions struct {
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
Permissions UserPermissions `json:"permissions"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account")
}
if user.Id != account.CreatedBy {
if user.Role != UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
for _, otherUser := range account.Users {
@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
@@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
}
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
log.Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err)
@@ -1861,9 +1862,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
AccountID: account.Id,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1907,7 +1909,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
users[userID] = NewOwnerUser(userID, accountID)
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}

View File

@@ -11,133 +11,134 @@ type Code struct {
Code string
}
// Existing consts must not be changed, as this will break the compatibility with the existing data
const (
// PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota
PeerAddedByUser Activity = 0
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
PeerAddedWithSetupKey
PeerAddedWithSetupKey Activity = 1
// UserJoined indicates that a new user joined the account
UserJoined
UserJoined Activity = 2
// UserInvited indicates that a new user was invited to join the account
UserInvited
UserInvited Activity = 3
// AccountCreated indicates that a new account has been created
AccountCreated
AccountCreated Activity = 4
// PeerRemovedByUser indicates that a user removed a peer from the system
PeerRemovedByUser
PeerRemovedByUser Activity = 5
// RuleAdded indicates that a user added a new rule
RuleAdded
RuleAdded Activity = 6
// RuleUpdated indicates that a user updated a rule
RuleUpdated
RuleUpdated Activity = 7
// RuleRemoved indicates that a user removed a rule
RuleRemoved
RuleRemoved Activity = 8
// PolicyAdded indicates that a user added a new policy
PolicyAdded
PolicyAdded Activity = 9
// PolicyUpdated indicates that a user updated a policy
PolicyUpdated
PolicyUpdated Activity = 10
// PolicyRemoved indicates that a user removed a policy
PolicyRemoved
PolicyRemoved Activity = 11
// SetupKeyCreated indicates that a user created a new setup key
SetupKeyCreated
SetupKeyCreated Activity = 12
// SetupKeyUpdated indicates that a user updated a setup key
SetupKeyUpdated
SetupKeyUpdated Activity = 13
// SetupKeyRevoked indicates that a user revoked a setup key
SetupKeyRevoked
SetupKeyRevoked Activity = 14
// SetupKeyOverused indicates that setup key usage exhausted
SetupKeyOverused
SetupKeyOverused Activity = 15
// GroupCreated indicates that a user created a group
GroupCreated
GroupCreated Activity = 16
// GroupUpdated indicates that a user updated a group
GroupUpdated
GroupUpdated Activity = 17
// GroupAddedToPeer indicates that a user added group to a peer
GroupAddedToPeer
GroupAddedToPeer Activity = 18
// GroupRemovedFromPeer indicates that a user removed peer group
GroupRemovedFromPeer
GroupRemovedFromPeer Activity = 19
// GroupAddedToUser indicates that a user added group to a user
GroupAddedToUser
GroupAddedToUser Activity = 20
// GroupRemovedFromUser indicates that a user removed a group from a user
GroupRemovedFromUser
GroupRemovedFromUser Activity = 21
// UserRoleUpdated indicates that a user changed the role of a user
UserRoleUpdated
UserRoleUpdated Activity = 22
// GroupAddedToSetupKey indicates that a user added group to a setup key
GroupAddedToSetupKey
GroupAddedToSetupKey Activity = 23
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
GroupRemovedFromSetupKey
GroupRemovedFromSetupKey Activity = 24
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
GroupAddedToDisabledManagementGroups
GroupAddedToDisabledManagementGroups Activity = 25
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
GroupRemovedFromDisabledManagementGroups
GroupRemovedFromDisabledManagementGroups Activity = 26
// RouteCreated indicates that a user created a route
RouteCreated
RouteCreated Activity = 27
// RouteRemoved indicates that a user deleted a route
RouteRemoved
RouteRemoved Activity = 28
// RouteUpdated indicates that a user updated a route
RouteUpdated
RouteUpdated Activity = 29
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
PeerSSHEnabled
PeerSSHEnabled Activity = 30
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
PeerSSHDisabled
PeerSSHDisabled Activity = 31
// PeerRenamed indicates that a user renamed a peer
PeerRenamed
PeerRenamed Activity = 32
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
PeerLoginExpirationEnabled
PeerLoginExpirationEnabled Activity = 33
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
PeerLoginExpirationDisabled
PeerLoginExpirationDisabled Activity = 34
// NameserverGroupCreated indicates that a user created a nameservers group
NameserverGroupCreated
NameserverGroupCreated Activity = 35
// NameserverGroupDeleted indicates that a user deleted a nameservers group
NameserverGroupDeleted
NameserverGroupDeleted Activity = 36
// NameserverGroupUpdated indicates that a user updated a nameservers group
NameserverGroupUpdated
NameserverGroupUpdated Activity = 37
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
AccountPeerLoginExpirationEnabled
AccountPeerLoginExpirationEnabled Activity = 38
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
AccountPeerLoginExpirationDisabled
AccountPeerLoginExpirationDisabled Activity = 39
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated
AccountPeerLoginExpirationDurationUpdated Activity = 40
// PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated
PersonalAccessTokenCreated Activity = 41
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted
PersonalAccessTokenDeleted Activity = 42
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated
ServiceUserCreated Activity = 43
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
ServiceUserDeleted Activity = 44
// UserBlocked indicates that a user blocked another user
UserBlocked
UserBlocked Activity = 45
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
UserUnblocked Activity = 46
// UserDeleted indicates that a user deleted another user
UserDeleted
UserDeleted Activity = 47
// GroupDeleted indicates that a user deleted group
GroupDeleted
GroupDeleted Activity = 48
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer
UserLoggedInPeer Activity = 49
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired
PeerLoginExpired Activity = 50
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin
DashboardLogin Activity = 51
// IntegrationCreated indicates that the user created an integration
IntegrationCreated
IntegrationCreated Activity = 52
// IntegrationUpdated indicates that the user updated an integration
IntegrationUpdated
IntegrationUpdated Activity = 53
// IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted
IntegrationDeleted Activity = 54
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled
AccountPeerApprovalEnabled Activity = 55
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled
AccountPeerApprovalDisabled Activity = 56
// PeerApproved indicates that the peer has been approved
PeerApproved
PeerApproved Activity = 57
// PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked
PeerApprovalRevoked Activity = 58
// TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole
TransferredOwnerRole Activity = 59
// PostureCheckCreated indicates that the user created a posture check
PostureCheckCreated
PostureCheckCreated Activity = 60
// PostureCheckUpdated indicates that the user updated a posture check
PostureCheckUpdated
PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted
PostureCheckDeleted Activity = 62
)
var activityMap = map[Activity]Code{

View File

@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }

View File

@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user"),
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
},
DNSSettings: baseExistingDNSSettings,
}

View File

@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
},
}
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)

View File

@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {

View File

@@ -12,6 +12,7 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -38,7 +39,7 @@ type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

View File

@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
return nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true)

View File

@@ -551,8 +551,8 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if requiresApproval {
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
@@ -563,11 +563,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
am.updateAccountPeers(account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), nil
}
// LoginPeer logs in or registers a peer.

View File

@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
case <-ticker.C:
select {
case <-cancel:
log.Debugf("scheduled job %s was canceled, stop timer", ID)
log.Tracef("scheduled job %s was canceled, stop timer", ID)
ticker.Stop()
return
default:
log.Debugf("time to do a scheduled job %s", ID)
log.Tracef("time to do a scheduled job %s", ID)
}
runIn, reschedule := job()
if !reschedule {
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
log.Debugf("job %s is not scheduled to run again", ID)
log.Tracef("job %s is not scheduled to run again", ID)
ticker.Stop()
return
}
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
ticker.Reset(runIn)
}
case <-cancel:
log.Debugf("job %s was canceled, stopping timer", ID)
log.Tracef("job %s was canceled, stopping timer", ID)
ticker.Stop()
return
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock
}
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
// Get the reflect.Value of the records slice
v := reflect.ValueOf(records)
if v.Kind() != reflect.Slice {
return fmt.Errorf("provided input is not a slice")
}
// Insert records in batches
for i := 0; i < v.Len(); i += batchSize {
end := i + batchSize
if end > v.Len() {
end = v.Len()
}
// Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err
}
}
return nil
}
func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
// operate over a fresh copy as we will modify its fields
accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
//we need an explicit reference to the account for gorm
key.AccountID = accCopy.Id
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
}
for id, peer := range account.Peers {
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
//we need an explicit reference to the account for gorm
peer.AccountID = accCopy.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
}
for id, user := range account.Users {
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id
//we need an explicit reference to the account for gorm
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
accCopy.UsersG = append(accCopy.UsersG, *user)
}
for id, group := range account.Groups {
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
//we need an explicit reference to the account for gorm
group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
}
for id, route := range account.Routes {
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
//we need an explicit reference to the account for gorm
route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
//we need an explicit reference to the account for gorm
ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
Create(accCopy)
if result.Error != nil {
return result.Error
}
return nil
const batchSize = 500
err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
if err != nil {
return err
}
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err
}
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now()
account.UsersG = make([]User, 0, len(account.Users))
for id, user := range account.Users {
user.Id = id
//we need an explicit reference to an account as it is missing for some reason
user.AccountID = account.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {

View File

@@ -2,7 +2,12 @@ package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
admin := account.Users["testuser"]
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
HashedToken: "hashed token",
}}
err := store.SaveAccount(account)
require.NoError(t, err)
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
store := newSqliteStore(t)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user := NewAdminUser(testUserID, "account_id")
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account)
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}

View File

@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
accountID string) *User {
return &User{
Id: id,
Id: ID,
AccountID: accountID,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
func NewRegularUser(ID, accountID string) *User {
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
func NewAdminUser(ID, accountID string) *User {
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
func NewOwnerUser(ID, accountID string) *User {
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser

View File

@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
err := store.SaveAccount(account)
if err != nil {
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID)
user := NewAdminUser(mockUserID, mockAccountID)
assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID)
user = NewRegularUser(mockUserID, mockAccountID)
assert.False(t, user.HasAdminPower())
}
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(account)
if err != nil {

View File

@@ -3,6 +3,8 @@ package grpc
import (
"context"
"net"
"os/user"
"runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -12,6 +14,20 @@ import (
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
log.Fatalf("failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)

View File

@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
@@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
@@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr

View File

@@ -8,6 +8,7 @@ import (
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)

View File

@@ -1,10 +1,16 @@
package net
import "github.com/google/uuid"
import (
"os"
"github.com/google/uuid"
)
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// ConnectionID provides a globally unique identifier for network connections.
@@ -15,3 +21,7 @@ type ConnectionID string
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
setErr = SetSocketOpt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil
}
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}